mirror of
https://github.com/eliasstepanik/vdo.ninja.git
synced 2026-01-11 21:58:35 +00:00
16301 lines
879 KiB
JavaScript
16301 lines
879 KiB
JavaScript
/**
|
|
* @license
|
|
* Copyright 2021 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
(function (global, factory) {
|
|
typeof exports === 'object' && typeof module !== 'undefined' ? factory(exports, require('@tensorflow/tfjs-core')) :
|
|
typeof define === 'function' && define.amd ? define(['exports', '@tensorflow/tfjs-core'], factory) :
|
|
(global = global || self, factory(global.tf = global.tf || {}, global.tf));
|
|
}(this, (function (exports, tf) { 'use strict';
|
|
|
|
/*! *****************************************************************************
|
|
Copyright (c) Microsoft Corporation.
|
|
|
|
Permission to use, copy, modify, and/or distribute this software for any
|
|
purpose with or without fee is hereby granted.
|
|
|
|
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH
|
|
REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY
|
|
AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT,
|
|
INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM
|
|
LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR
|
|
OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
|
|
PERFORMANCE OF THIS SOFTWARE.
|
|
***************************************************************************** */
|
|
/* global Reflect, Promise */
|
|
|
|
/**
 * Links constructor `d` to constructor `b` so that `d` inherits `b`'s static
 * members. The first call selects the best strategy available at runtime and
 * overwrites `extendStatics` with it, so later calls go straight to it.
 */
var extendStatics = function (d, b) {
    // Fallback 1: assign the legacy __proto__ accessor.
    var assignProto = function (target, base) { target.__proto__ = base; };
    // Fallback 2: copy own enumerable properties one by one.
    var copyOwnProps = function (target, base) {
        for (var key in base) {
            if (base.hasOwnProperty(key)) {
                target[key] = base[key];
            }
        }
    };
    if (Object.setPrototypeOf) {
        extendStatics = Object.setPrototypeOf;
    }
    else if ({ __proto__: [] } instanceof Array) {
        extendStatics = assignProto;
    }
    else {
        extendStatics = copyOwnProps;
    }
    return extendStatics(d, b);
};
|
|
|
|
/**
 * Makes constructor `d` inherit from constructor `b`: static members via
 * `extendStatics`, instance members via a prototype chain whose `constructor`
 * points back at `d`. When `b` is `null`, `d.prototype` becomes an object
 * with no prototype.
 */
function __extends(d, b) {
    function Surrogate() { this.constructor = d; }
    extendStatics(d, b);
    if (b === null) {
        d.prototype = Object.create(b);
        return;
    }
    Surrogate.prototype = b.prototype;
    d.prototype = new Surrogate();
}
|
|
|
|
/**
 * Drives a generator function as if it were an async function.
 *
 * @param thisArg `this` value for the generator function.
 * @param _arguments Arguments for the generator function (defaults to []).
 * @param P Promise constructor to use (defaults to the global `Promise`).
 * @param generator Generator function; each yielded value is awaited.
 * @returns A promise settled with the generator's return value or error.
 */
function __awaiter(thisArg, _arguments, P, generator) {
    // Wraps plain values so that `.then` can always be called on them.
    function adopt(value) {
        if (value instanceof P) {
            return value;
        }
        return new P(function (resolve) { resolve(value); });
    }
    if (!P) {
        P = Promise;
    }
    return new P(function (resolve, reject) {
        // Resumes the generator through `method` and routes any synchronous
        // throw to the outer promise.
        function resume(method, value) {
            try {
                step(generator[method](value));
            }
            catch (e) {
                reject(e);
            }
        }
        function fulfilled(value) { resume("next", value); }
        function rejected(value) { resume("throw", value); }
        function step(result) {
            if (result.done) {
                resolve(result.value);
            }
            else {
                adopt(result.value).then(fulfilled, rejected);
            }
        }
        generator = generator.apply(thisArg, _arguments || []);
        step(generator.next());
    });
}
|
|
|
|
/**
 * Runtime for down-levelled generator functions. `body` is a state machine
 * that receives the shared state object (`label`, `sent`, `trys`, `ops`) and
 * returns an instruction pair `[opcode, value]`; this helper interprets those
 * pairs and exposes a `next` / `throw` / `return` iterator object.
 */
function __generator(thisArg, body) {
    var running, delegate, tmp, iterator;
    var state = {
        label: 0,
        sent: function () { if (tmp[0] & 1) throw tmp[1]; return tmp[1]; },
        trys: [],
        ops: []
    };
    iterator = { next: verb(0), "throw": verb(1), "return": verb(2) };
    if (typeof Symbol === "function") {
        iterator[Symbol.iterator] = function () { return this; };
    }
    return iterator;
    function verb(n) { return function (v) { return step([n, v]); }; }
    function step(op) {
        if (running) throw new TypeError("Generator is already executing.");
        // `state` is set to 0 to leave the loop once the generator finishes.
        while (state) try {
            // While a delegate iterator is active, forward the request to it.
            if (running = 1, delegate && (tmp = op[0] & 2 ? delegate["return"] : op[0] ? delegate["throw"] || ((tmp = delegate["return"]) && tmp.call(delegate), 0) : delegate.next) && !(tmp = tmp.call(delegate, op[1])).done) return tmp;
            if (delegate = 0, tmp) op = [op[0] & 2, tmp.value];
            switch (op[0]) {
                case 0:
                case 1:
                    tmp = op;
                    break;
                case 4:
                    state.label++;
                    return { value: op[1], done: false };
                case 5:
                    state.label++;
                    delegate = op[1];
                    op = [0];
                    continue;
                case 7:
                    op = state.ops.pop();
                    state.trys.pop();
                    continue;
                default:
                    if (!(tmp = state.trys, tmp = tmp.length > 0 && tmp[tmp.length - 1]) && (op[0] === 6 || op[0] === 2)) { state = 0; continue; }
                    if (op[0] === 3 && (!tmp || (op[1] > tmp[0] && op[1] < tmp[3]))) { state.label = op[1]; break; }
                    if (op[0] === 6 && state.label < tmp[1]) { state.label = tmp[1]; tmp = op; break; }
                    if (tmp && state.label < tmp[2]) { state.label = tmp[2]; state.ops.push(op); break; }
                    if (tmp[2]) state.ops.pop();
                    state.trys.pop();
                    continue;
            }
            op = body.call(thisArg, state);
        }
        catch (e) {
            op = [6, e];
            delegate = 0;
        }
        finally {
            running = tmp = 0;
        }
        if (op[0] & 5) throw op[1];
        return { value: op[0] ? op[1] : void 0, done: true };
    }
}
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// Rendering contexts keyed by WebGL version number.
var contexts = {};
// Context creation attributes passed to every canvas.getContext() call made
// in this file.
var WEBGL_ATTRIBUTES = {
    alpha: false,
    antialias: false,
    premultipliedAlpha: false,
    preserveDrawingBuffer: false,
    depth: false,
    stencil: false,
    failIfMajorPerformanceCaveat: true
};
|
|
/**
 * Registers `gl` as the cached rendering context for `webGLVersion`,
 * replacing any context previously stored under that version.
 */
function setWebGLContext(webGLVersion, gl) {
    var cache = contexts;
    cache[webGLVersion] = gl;
}
|
|
/**
 * Returns the cached rendering context for `webGLVersion`, creating it on
 * first use. A cached context that reports itself lost is dropped and a new
 * one is requested. Before returning, the context's fixed-function state is
 * reset to the configuration applied below.
 *
 * @returns The context, or `null` (after logging) when none can be created.
 */
function getWebGLContext(webGLVersion) {
    var isCached = webGLVersion in contexts;
    if (!isCached) {
        var created = getWebGLRenderingContext(webGLVersion);
        if (created === null) {
            console.log('Could not get context for WebGL version', webGLVersion);
            return null;
        }
        contexts[webGLVersion] = created;
    }
    var gl = contexts[webGLVersion];
    if (gl.isContextLost()) {
        // Forget the dead context and retry from scratch.
        delete contexts[webGLVersion];
        return getWebGLContext(webGLVersion);
    }
    var disabledCapabilities = [
        gl.DEPTH_TEST, gl.STENCIL_TEST, gl.BLEND, gl.DITHER,
        gl.POLYGON_OFFSET_FILL, gl.SAMPLE_COVERAGE
    ];
    for (var i = 0; i < disabledCapabilities.length; i++) {
        gl.disable(disabledCapabilities[i]);
    }
    gl.enable(gl.SCISSOR_TEST);
    gl.enable(gl.CULL_FACE);
    gl.cullFace(gl.BACK);
    return contexts[webGLVersion];
}
|
|
/**
 * Creates a canvas to host a rendering context: an `OffscreenCanvas` when
 * that class exists and WebGL 2 is requested, otherwise a DOM `<canvas>`.
 *
 * @throws Error when neither `OffscreenCanvas` (for version 2) nor
 *     `document` is available.
 */
function createCanvas(webGLVersion) {
    var canUseOffscreen = typeof OffscreenCanvas !== 'undefined' && webGLVersion === 2;
    if (canUseOffscreen) {
        return new OffscreenCanvas(300, 150);
    }
    if (typeof document === 'undefined') {
        throw new Error('Cannot create a canvas in this context');
    }
    return document.createElement('canvas');
}
|
|
/**
 * Creates a new canvas and requests a rendering context of the given version
 * from it. A `webglcontextlost` listener removes the cached context for that
 * version when the context is lost.
 *
 * @throws Error when `webGLVersion` is neither 1 nor 2.
 * @returns Whatever `canvas.getContext` returns for the requested version.
 */
function getWebGLRenderingContext(webGLVersion) {
    var isSupportedVersion = webGLVersion === 1 || webGLVersion === 2;
    if (!isSupportedVersion) {
        throw new Error('Cannot get WebGL rendering context, WebGL is disabled.');
    }
    var canvas = createCanvas(webGLVersion);
    var onContextLost = function (ev) {
        ev.preventDefault();
        delete contexts[webGLVersion];
    };
    canvas.addEventListener('webglcontextlost', onContextLost, false);
    if (webGLVersion === 2) {
        return canvas.getContext('webgl2', WEBGL_ATTRIBUTES);
    }
    // Version 1: try the standard context name before the legacy one.
    var standard = canvas.getContext('webgl', WEBGL_ATTRIBUTES);
    return standard || canvas.getContext('experimental-webgl', WEBGL_ATTRIBUTES);
}
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2017 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * How tensor values are arranged inside texels. The object maps names to
 * numbers and numbers back to names.
 *
 * DENSE (0): all values in a single texel are densely packed without any
 * constraints. This is how the shader encodes a tensor with
 * shape = [2, 3, 4] (indices are [batch, row, col]):
 *
 * 000|001 010|011 020|021
 * ------- ------- -------
 * 002|003 012|013 022|023
 *
 * 100|101 110|111 120|121
 * ------- ------- -------
 * 102|103 112|113 122|123
 *
 * SHARED_BATCH (1): single texels contain only values from the same batch,
 * and from adjacent rows and columns. This is how the shader encodes a tensor
 * with shape = [2, 3, 5] (indices are [batch, row, col]):
 *
 * 000|001 002|003 004|xxx 020|021 022|023 024|xxx
 * ------- ------- ------- ------- ------- -------
 * 010|011 012|013 014|xxx xxx|xxx xxx|xxx xxx|xxx
 *
 * 100|101 102|103 104|xxx 120|121 122|123 124|xxx
 * ------- ------- ------- ------- ------- -------
 * 110|111 112|113 114|xxx xxx|xxx xxx|xxx xxx|xxx
 */
var PackingScheme;
(function (scheme) {
    ['DENSE', 'SHARED_BATCH'].forEach(function (name, value) {
        scheme[name] = value;
        scheme[value] = name;
    });
})(PackingScheme || (PackingScheme = {}));
|
|
// Texture usage kinds; the object maps names to numbers and numbers to names.
var TextureUsage;
(function (usage) {
    ['RENDER', 'UPLOAD', 'PIXELS', 'DOWNLOAD'].forEach(function (name, value) {
        usage[name] = value;
        usage[value] = name;
    });
})(TextureUsage || (TextureUsage = {}));
|
|
// Physical texture layouts; the object maps names to numbers and numbers to
// names.
var PhysicalTextureType;
(function (types) {
    var names = [
        'UNPACKED_FLOAT16',
        'UNPACKED_FLOAT32',
        'PACKED_4X1_UNSIGNED_BYTE',
        'PACKED_2X2_FLOAT32',
        'PACKED_2X2_FLOAT16'
    ];
    for (var value = 0; value < names.length; value++) {
        types[names[value]] = value;
        types[value] = names[value];
    }
})(PhysicalTextureType || (PhysicalTextureType = {}));
|
|
/**
 * Returns `[width, height]` for an unpacked matrix texture: the column count
 * is the width and the row count is the height.
 */
function getUnpackedMatrixTextureShapeWidthHeight(rows, columns) {
    var width = columns;
    var height = rows;
    return [width, height];
}
|
|
/**
 * Returns the array length for `matrixSize` entries at `channelsPerTexture`
 * channels each (the product of the two arguments).
 */
function getUnpackedArraySizeFromMatrixSize(matrixSize, channelsPerTexture) {
    var arraySize = matrixSize * channelsPerTexture;
    return arraySize;
}
|
|
/**
 * Computes the texture shape for a densely packed RGBA texture holding a
 * tensor of the given logical shape: four values per texel, rounded up, then
 * arranged by `tf.util.sizeToSquarishShape`.
 */
function getDenseTexShape(shape) {
    var valuesPerTexel = 4;
    var texelCount = Math.ceil(tf.util.sizeFromShape(shape) / valuesPerTexel);
    return tf.util.sizeToSquarishShape(texelCount);
}
|
|
/**
 * Returns `[width, height]` in texels for a packed matrix texture: each
 * dimension is halved (rounded up) and is never smaller than 1.
 */
function getPackedMatrixTextureShapeWidthHeight(rows, columns) {
    var width = Math.max(1, Math.ceil(columns / 2));
    var height = Math.max(1, Math.ceil(rows / 2));
    return [width, height];
}
|
|
/**
 * Returns the number of array entries in a packed RGBA texture for a matrix
 * of the given size: four channels for every texel of the packed shape.
 */
function getPackedRGBAArraySizeFromMatrixShape(rows, columns) {
    var packedShape = getPackedMatrixTextureShapeWidthHeight(rows, columns);
    var texelCount = packedShape[0] * packedShape[1];
    return texelCount * 4;
}
|
|
/**
 * Builds the table of texture formats and types used by this backend,
 * selected by the `WEBGL_VERSION` environment number.
 *
 * @param gl Rendering context whose enum constants are read.
 * @param textureHalfFloatExtension Optional extension object; when WebGL 2 is
 *     not active, its `HALF_FLOAT_OES` supplies the half-float texture type
 *     (or `null` when the extension is absent).
 * @returns Object holding internal formats, texture formats, channel counts
 *     and texture types.
 */
function getTextureConfig(gl, textureHalfFloatExtension) {
    var usesWebGL2 = tf.env().getNumber('WEBGL_VERSION') === 2;
    var halfFloatType;
    if (usesWebGL2) {
        halfFloatType = gl.HALF_FLOAT;
    }
    else if (textureHalfFloatExtension != null) {
        halfFloatType = textureHalfFloatExtension.HALF_FLOAT_OES;
    }
    else {
        halfFloatType = null;
    }
    return {
        internalFormatFloat: usesWebGL2 ? gl.R32F : gl.RGBA,
        internalFormatHalfFloat: usesWebGL2 ? gl.R16F : gl.RGBA,
        internalFormatPackedHalfFloat: usesWebGL2 ? gl.RGBA16F : gl.RGBA,
        internalFormatPackedFloat: usesWebGL2 ? gl.RGBA32F : gl.RGBA,
        textureFormatFloat: usesWebGL2 ? gl.RED : gl.RGBA,
        downloadTextureFormat: gl.RGBA,
        downloadUnpackNumChannels: 4,
        defaultNumChannels: usesWebGL2 ? 1 : 4,
        textureTypeHalfFloat: halfFloatType,
        textureTypeFloat: gl.FLOAT
    };
}
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2017 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Invokes `func` and returns its result. When the `DEBUG` environment flag is
 * set, the context's error state is checked afterwards via `checkWebGLError`.
 */
function callAndCheck(gl, func) {
    var result = func();
    var isDebugging = tf.env().getBool('DEBUG');
    if (isDebugging) {
        checkWebGLError(gl);
    }
    return result;
}
|
|
/**
 * Reads the context's error flag with `gl.getError()`.
 *
 * @throws Error naming the error code when the flag is not `gl.NO_ERROR`.
 */
function checkWebGLError(gl) {
    var code = gl.getError();
    if (code === gl.NO_ERROR) {
        return;
    }
    throw new Error('WebGL Error: ' + getWebGLErrorMessage(gl, code));
}
|
|
// https://en.wikipedia.org/wiki/Half-precision_floating-point_format
// Approximate smallest positive and exact largest finite half-float values.
var MIN_FLOAT16 = 5.96e-8;
var MAX_FLOAT16 = 65504;
/**
 * Reports whether `num` fits the precision this backend renders with.
 *
 * Any number is accepted when the `WEBGL_RENDER_FLOAT32_ENABLED` flag is set.
 * Otherwise `num` must be zero or have a magnitude within the half-float
 * range. The upper bound is inclusive: `MAX_FLOAT16` is itself a finite
 * half-float value, so rejecting it (as the previous strict comparison did)
 * was a false negative.
 *
 * @param num Number to test.
 * @returns `true` when the value is considered representable.
 */
function canBeRepresented(num) {
    if (tf.env().getBool('WEBGL_RENDER_FLOAT32_ENABLED') || num === 0) {
        return true;
    }
    var magnitude = Math.abs(num);
    return MIN_FLOAT16 < magnitude && magnitude <= MAX_FLOAT16;
}
|
|
/**
 * Translates an error code into the name of the matching `gl` constant, or
 * an "Unknown error code" message when no known constant matches.
 */
function getWebGLErrorMessage(gl, status) {
    // Checked in this order with strict equality; first match wins.
    var knownErrors = [
        'NO_ERROR',
        'INVALID_ENUM',
        'INVALID_VALUE',
        'INVALID_OPERATION',
        'INVALID_FRAMEBUFFER_OPERATION',
        'OUT_OF_MEMORY',
        'CONTEXT_LOST_WEBGL'
    ];
    for (var i = 0; i < knownErrors.length; i++) {
        if (status === gl[knownErrors[i]]) {
            return knownErrors[i];
        }
    }
    return "Unknown error code " + status;
}
|
|
/**
 * Fetches the named extension from `gl` through `throwIfNull`, which raises
 * an Error with the message built here when the extension is unavailable.
 */
function getExtensionOrThrow(gl, extensionName) {
    var failureMessage = 'Extension "' + extensionName + '" not supported on this browser.';
    var fetchExtension = function () { return gl.getExtension(extensionName); };
    return throwIfNull(gl, fetchExtension, failureMessage);
}
|
|
/**
 * Creates and compiles a vertex shader from GLSL source.
 *
 * @throws Error when the shader object cannot be created, or when
 *     compilation fails (the info log is written to the console first).
 * @returns The compiled shader object.
 */
function createVertexShader(gl, vertexShaderSource) {
    var shader = throwIfNull(gl, function () { return gl.createShader(gl.VERTEX_SHADER); }, 'Unable to create vertex WebGLShader.');
    callAndCheck(gl, function () { return gl.shaderSource(shader, vertexShaderSource); });
    callAndCheck(gl, function () { return gl.compileShader(shader); });
    var compileStatus = gl.getShaderParameter(shader, gl.COMPILE_STATUS);
    if (compileStatus === false) {
        console.log(gl.getShaderInfoLog(shader));
        throw new Error('Failed to compile vertex shader.');
    }
    return shader;
}
|
|
/**
 * Creates and compiles a fragment shader from GLSL source.
 *
 * @throws Error when the shader object cannot be created, or when
 *     compilation fails (source and info log are logged first through
 *     `logShaderSourceAndInfoLog`).
 * @returns The compiled shader object.
 */
function createFragmentShader(gl, fragmentShaderSource) {
    var shader = throwIfNull(gl, function () { return gl.createShader(gl.FRAGMENT_SHADER); }, 'Unable to create fragment WebGLShader.');
    callAndCheck(gl, function () { return gl.shaderSource(shader, fragmentShaderSource); });
    callAndCheck(gl, function () { return gl.compileShader(shader); });
    var compileStatus = gl.getShaderParameter(shader, gl.COMPILE_STATUS);
    if (compileStatus === false) {
        logShaderSourceAndInfoLog(fragmentShaderSource, gl.getShaderInfoLog(shader));
        throw new Error('Failed to compile fragment shader.');
    }
    return shader;
}
|
|
// Captures the line number from an info-log entry of the form
// "ERROR: <n>:<line>:". The `g` flag makes exec() stateful through
// `lastIndex`, so users of this regex must reset it (see below).
var lineNumberRegex = /ERROR: [0-9]+:([0-9]+):/g;
/**
 * Logs a shader's source with line numbers, highlighting the line that the
 * first error in `shaderInfoLog` points to.
 *
 * When no line number can be parsed, the raw info log and the unnumbered
 * source are logged instead.
 *
 * @param shaderSource GLSL source text of the shader.
 * @param shaderInfoLog Compiler info log for that shader.
 */
function logShaderSourceAndInfoLog(shaderSource, shaderInfoLog) {
    // The regex is global, so exec() resumes from the previous `lastIndex`.
    // Without these resets a second call would start searching past the
    // previous match and could miss the line number altogether.
    lineNumberRegex.lastIndex = 0;
    var match = lineNumberRegex.exec(shaderInfoLog);
    lineNumberRegex.lastIndex = 0;
    if (match == null) {
        console.log("Couldn't parse line number in error: " + shaderInfoLog);
        console.log(shaderSource);
        return;
    }
    var errorLineNumber = +match[1];
    var shaderLines = shaderSource.split('\n');
    // Gutter width: digits of the largest line number plus two spaces.
    var pad = shaderLines.length.toString().length + 2;
    var numberedLines = shaderLines.map(function (line, index) {
        return tf.util.rightPad((index + 1).toString(), pad) + line;
    });
    var maxLineLength = 0;
    for (var i = 0; i < numberedLines.length; i++) {
        maxLineLength = Math.max(numberedLines[i].length, maxLineLength);
    }
    var beforeErrorLines = numberedLines.slice(0, errorLineNumber - 1);
    var errorLine = numberedLines.slice(errorLineNumber - 1, errorLineNumber);
    var afterErrorLines = numberedLines.slice(errorLineNumber);
    // The reported line may lie outside the source; pad an empty string then
    // instead of handing `undefined` to rightPad.
    var highlighted = errorLine.length > 0 ? errorLine[0] : '';
    console.log(beforeErrorLines.join('\n'));
    console.log(shaderInfoLog.split('\n')[0]);
    console.log("%c " + tf.util.rightPad(highlighted, maxLineLength), 'border:1px solid red; background-color:#e3d2d2; color:#a61717');
    console.log(afterErrorLines.join('\n'));
}
|
|
/**
 * Creates a new program object, throwing (through `throwIfNull`) when the
 * context returns none.
 */
function createProgram(gl) {
    var makeProgram = function () { return gl.createProgram(); };
    return throwIfNull(gl, makeProgram, 'Unable to create WebGLProgram.');
}
|
|
/**
 * Links `program`.
 *
 * @throws Error when the link status is `false`; the program info log is
 *     written to the console first.
 */
function linkProgram(gl, program) {
    callAndCheck(gl, function () { return gl.linkProgram(program); });
    var linkStatus = gl.getProgramParameter(program, gl.LINK_STATUS);
    if (linkStatus !== false) {
        return;
    }
    console.log(gl.getProgramInfoLog(program));
    throw new Error('Failed to link vertex and fragment shaders.');
}
|
|
/**
 * Validates `program`.
 *
 * @throws Error when the validate status is `false`; the program info log is
 *     written to the console first.
 */
function validateProgram(gl, program) {
    callAndCheck(gl, function () { return gl.validateProgram(program); });
    var validateStatus = gl.getProgramParameter(program, gl.VALIDATE_STATUS);
    if (validateStatus !== false) {
        return;
    }
    console.log(gl.getProgramInfoLog(program));
    throw new Error('Shader program validation failed.');
}
|
|
/**
 * Creates a buffer, binds it to `gl.ARRAY_BUFFER` and uploads `data` with
 * `gl.STATIC_DRAW` usage. The buffer is left bound.
 *
 * @returns The new buffer object.
 */
function createStaticVertexBuffer(gl, data) {
    var vertexBuffer = throwIfNull(gl, function () { return gl.createBuffer(); }, 'Unable to create WebGLBuffer');
    var bind = function () { return gl.bindBuffer(gl.ARRAY_BUFFER, vertexBuffer); };
    var upload = function () { return gl.bufferData(gl.ARRAY_BUFFER, data, gl.STATIC_DRAW); };
    callAndCheck(gl, bind);
    callAndCheck(gl, upload);
    return vertexBuffer;
}
|
|
/**
 * Creates a buffer, binds it to `gl.ELEMENT_ARRAY_BUFFER` and uploads `data`
 * with `gl.STATIC_DRAW` usage. The buffer is left bound.
 *
 * @returns The new buffer object.
 */
function createStaticIndexBuffer(gl, data) {
    var indexBuffer = throwIfNull(gl, function () { return gl.createBuffer(); }, 'Unable to create WebGLBuffer');
    var bind = function () { return gl.bindBuffer(gl.ELEMENT_ARRAY_BUFFER, indexBuffer); };
    var upload = function () { return gl.bufferData(gl.ELEMENT_ARRAY_BUFFER, data, gl.STATIC_DRAW); };
    callAndCheck(gl, bind);
    callAndCheck(gl, upload);
    return indexBuffer;
}
|
|
/**
 * Returns 1 when the `WEBGL_VERSION` environment number is 2, otherwise 4.
 */
function getNumChannels() {
    var usesWebGL2 = tf.env().getNumber('WEBGL_VERSION') === 2;
    return usesWebGL2 ? 1 : 4;
}
|
|
/**
 * Creates a new texture object, throwing (through `throwIfNull`) when the
 * context returns none.
 */
function createTexture(gl) {
    var makeTexture = function () { return gl.createTexture(); };
    return throwIfNull(gl, makeTexture, 'Unable to create WebGLTexture.');
}
|
|
/**
 * Checks a requested texture size against the `WEBGL_MAX_TEXTURE_SIZE`
 * environment number.
 *
 * @throws Error when either dimension is not positive, or when either
 *     dimension exceeds the maximum.
 */
function validateTextureSize(width, height) {
    var maxTextureSize = tf.env().getNumber('WEBGL_MAX_TEXTURE_SIZE');
    var requested = "[" + width + "x" + height + "]";
    var isNonPositive = width <= 0 || height <= 0;
    if (isNonPositive) {
        throw new Error('Requested texture size ' + requested + ' is invalid.');
    }
    var isTooLarge = width > maxTextureSize || height > maxTextureSize;
    if (isTooLarge) {
        var max = "[" + maxTextureSize + "x" + maxTextureSize + "]";
        throw new Error('Requested texture size ' + requested +
            ' greater than WebGL maximum on this browser / GPU ' + max + '.');
    }
}
|
|
/**
 * Creates a new framebuffer object, throwing (through `throwIfNull`) when
 * the context returns none.
 */
function createFramebuffer(gl) {
    var makeFramebuffer = function () { return gl.createFramebuffer(); };
    return throwIfNull(gl, makeFramebuffer, 'Unable to create WebGLFramebuffer.');
}
|
|
/**
 * Points a program's vertex attribute at `buffer` as float data and enables
 * the attribute array.
 *
 * @returns `false` when the attribute has no location in the program (for
 *     example because the GPU compiler stripped it as unused, in which case
 *     nothing needs binding); `true` once the attribute has been bound.
 */
function bindVertexBufferToProgramAttribute(gl, program, attribute, buffer, arrayEntriesPerItem, itemStrideInBytes, itemOffsetInBytes) {
    var location = gl.getAttribLocation(program, attribute);
    var isMissing = location === -1;
    if (isMissing) {
        return false;
    }
    var bind = function () { return gl.bindBuffer(gl.ARRAY_BUFFER, buffer); };
    var describe = function () {
        return gl.vertexAttribPointer(location, arrayEntriesPerItem, gl.FLOAT, false, itemStrideInBytes, itemOffsetInBytes);
    };
    var enable = function () { return gl.enableVertexAttribArray(location); };
    callAndCheck(gl, bind);
    callAndCheck(gl, describe);
    callAndCheck(gl, enable);
    return true;
}
|
|
/**
 * Validates `textureUnit`, makes it the active texture unit and binds
 * `texture` to `gl.TEXTURE_2D` on it.
 */
function bindTextureUnit(gl, texture, textureUnit) {
    validateTextureUnit(gl, textureUnit);
    var activate = function () { return gl.activeTexture(gl.TEXTURE0 + textureUnit); };
    var bind = function () { return gl.bindTexture(gl.TEXTURE_2D, texture); };
    callAndCheck(gl, activate);
    callAndCheck(gl, bind);
}
|
|
/**
 * Validates `textureUnit`, makes it the active texture unit and clears its
 * `gl.TEXTURE_2D` binding.
 */
function unbindTextureUnit(gl, textureUnit) {
    validateTextureUnit(gl, textureUnit);
    var activate = function () { return gl.activeTexture(gl.TEXTURE0 + textureUnit); };
    var unbind = function () { return gl.bindTexture(gl.TEXTURE_2D, null); };
    callAndCheck(gl, activate);
    callAndCheck(gl, unbind);
}
|
|
/**
 * Looks up a uniform's location in `program` through `throwIfNull`, which
 * raises an Error with the message built here when the lookup yields nothing.
 */
function getProgramUniformLocationOrThrow(gl, program, uniformName) {
    var failureMessage = 'uniform "' + uniformName + '" not present in program.';
    var lookup = function () { return gl.getUniformLocation(program, uniformName); };
    return throwIfNull(gl, lookup, failureMessage);
}
|
|
/**
 * Returns the result of `gl.getUniformLocation` for the named uniform,
 * without any checking.
 */
function getProgramUniformLocation(gl, program, uniformName) {
    var location = gl.getUniformLocation(program, uniformName);
    return location;
}
|
|
/**
 * Binds `texture` to `textureUnit` and sets the sampler uniform at
 * `uniformSamplerLocation` to that unit index.
 */
function bindTextureToProgramUniformSampler(gl, texture, uniformSamplerLocation, textureUnit) {
    var bindUnit = function () { return bindTextureUnit(gl, texture, textureUnit); };
    var setSampler = function () { return gl.uniform1i(uniformSamplerLocation, textureUnit); };
    callAndCheck(gl, bindUnit);
    callAndCheck(gl, setSampler);
}
|
|
/**
 * Clears the framebuffer binding and sets both viewport and scissor
 * rectangle to the full size of `gl.canvas`.
 */
function bindCanvasToFramebuffer(gl) {
    var steps = [
        function () { return gl.bindFramebuffer(gl.FRAMEBUFFER, null); },
        function () { return gl.viewport(0, 0, gl.canvas.width, gl.canvas.height); },
        function () { return gl.scissor(0, 0, gl.canvas.width, gl.canvas.height); }
    ];
    for (var i = 0; i < steps.length; i++) {
        callAndCheck(gl, steps[i]);
    }
}
|
|
/**
 * Binds `framebuffer` and attaches `texture` (level 0) as its
 * `COLOR_ATTACHMENT0`.
 */
function bindColorTextureToFramebuffer(gl, texture, framebuffer) {
    var bind = function () { return gl.bindFramebuffer(gl.FRAMEBUFFER, framebuffer); };
    var attach = function () {
        return gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, texture, 0);
    };
    callAndCheck(gl, bind);
    callAndCheck(gl, attach);
}
|
|
/**
 * Binds `framebuffer` and detaches whatever texture is on its
 * `COLOR_ATTACHMENT0` by attaching `null`.
 */
function unbindColorTextureFromFramebuffer(gl, framebuffer) {
    var bind = function () { return gl.bindFramebuffer(gl.FRAMEBUFFER, framebuffer); };
    var detach = function () {
        return gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, null, 0);
    };
    callAndCheck(gl, bind);
    callAndCheck(gl, detach);
}
|
|
/**
 * Checks the status of the currently bound framebuffer.
 *
 * @throws Error describing the status when it is not
 *     `gl.FRAMEBUFFER_COMPLETE`.
 */
function validateFramebuffer(gl) {
    var framebufferStatus = gl.checkFramebufferStatus(gl.FRAMEBUFFER);
    if (framebufferStatus === gl.FRAMEBUFFER_COMPLETE) {
        return;
    }
    throw new Error('Error binding framebuffer: ' + getFramebufferErrorMessage(gl, framebufferStatus));
}
|
|
/**
 * Translates a framebuffer status code into the name of the matching `gl`
 * constant, or an "unknown error" message when none matches.
 */
function getFramebufferErrorMessage(gl, status) {
    // Checked in this order with strict equality; first match wins.
    var knownStatuses = [
        'FRAMEBUFFER_INCOMPLETE_ATTACHMENT',
        'FRAMEBUFFER_INCOMPLETE_MISSING_ATTACHMENT',
        'FRAMEBUFFER_INCOMPLETE_DIMENSIONS',
        'FRAMEBUFFER_UNSUPPORTED'
    ];
    for (var i = 0; i < knownStatuses.length; i++) {
        if (status === gl[knownStatuses[i]]) {
            return knownStatuses[i];
        }
    }
    return "unknown error " + status;
}
|
|
/**
 * Runs `returnTOrNull` through `callAndCheck` and returns its result.
 *
 * @throws Error with `failureMessage` when the result is `null` or
 *     `undefined`.
 */
function throwIfNull(gl, returnTOrNull, failureMessage) {
    var produced = callAndCheck(gl, function () { return returnTOrNull(); });
    var isMissing = produced == null;
    if (isMissing) {
        throw new Error(failureMessage);
    }
    return produced;
}
|
|
/**
 * Checks that `textureUnit` is a valid zero-based texture unit index for
 * this context.
 *
 * `gl.MAX_COMBINED_TEXTURE_IMAGE_UNITS` is only the enum token that names the
 * limit; the limit itself has to be read with `gl.getParameter`. The previous
 * code compared against the token's numeric value, so the upper bound was
 * effectively never enforced.
 *
 * If the limit cannot be read (`getParameter` does not return a number), no
 * validation is performed.
 *
 * @param gl Rendering context.
 * @param textureUnit Zero-based texture unit index.
 * @throws Error when `textureUnit` lies outside `[0, limit - 1]`.
 */
function validateTextureUnit(gl, textureUnit) {
    var unitCount = gl.getParameter(gl.MAX_COMBINED_TEXTURE_IMAGE_UNITS);
    if (typeof unitCount !== 'number') {
        return;
    }
    var maxTextureUnit = unitCount - 1;
    if (textureUnit < 0 || textureUnit > maxTextureUnit) {
        var textureUnitRange = "[gl.TEXTURE0, gl.TEXTURE" + maxTextureUnit + "]";
        throw new Error("textureUnit must be in " + textureUnitRange + ".");
    }
}
|
|
/**
 * Returns the product of all dimensions of `shape` except the trailing
 * `dimsToSkip` ones (2 by default), as computed by `tf.util.sizeFromShape`.
 */
function getBatchDim(shape, dimsToSkip) {
    var trailing = dimsToSkip === void 0 ? 2 : dimsToSkip;
    var leadingDims = shape.slice(0, shape.length - trailing);
    return tf.util.sizeFromShape(leadingDims);
}
|
|
/**
 * Returns `[rows, cols]` taken from the last two dimensions of `shape`; a
 * one-dimensional shape is treated as a single row.
 *
 * @throws Error when `shape` is empty.
 */
function getRowsCols(shape) {
    var rank = shape.length;
    if (rank === 0) {
        throw Error('Cannot get rows and columns of an empty shape array.');
    }
    var rows = rank > 1 ? shape[rank - 2] : 1;
    var cols = shape[rank - 1];
    return [rows, cols];
}
|
|
/**
 * Collapses `shape` to `[batch, rows, cols]`. An empty shape or `[1]` is
 * treated as a scalar and yields `[1, 1, 1]`.
 */
function getShapeAs3D(shape) {
    var rank = shape.length;
    var isScalar = rank === 0 || (rank === 1 && shape[0] === 1);
    if (isScalar) {
        return [1, 1, 1];
    }
    var batch = getBatchDim(shape);
    var rowsCols = getRowsCols(shape);
    return [batch].concat(rowsCols);
}
|
|
/**
 * Chooses a 2D texture shape `[rows, cols]` for a tensor of logical shape
 * `logShape`.
 *
 * Shapes of rank 1 to 4 (after squeezing) map directly onto a texture when
 * the resulting dimensions fit within the `WEBGL_MAX_TEXTURE_SIZE`
 * environment number; everything else falls back to a squarish shape.
 *
 * @param logShape Logical tensor shape.
 * @param isPacked Whether the texture is packed (default `false`).
 * @returns Two-element texture shape.
 */
function getTextureShapeFromLogicalShape(logShape, isPacked) {
    var packed = isPacked === void 0 ? false : isPacked;
    var limit = tf.env().getNumber('WEBGL_MAX_TEXTURE_SIZE');
    var shape = logShape;
    if (packed) {
        limit = limit * 2;
        // Values share a texel only when they come from adjacent pairs of
        // rows/cols within the same batch, so the two innermost dimensions
        // are rounded up to even numbers: a tensor with 3 rows is counted as
        // having 4, because the texels holding the third row are half empty.
        var firstInnerAxis = shape.length - 2;
        shape = shape.map(function (dim, axis) {
            return axis >= firstInnerAxis ? tf.util.nearestLargerEven(dim) : dim;
        });
        // Packed texture height is at least 2 (the channel height of a
        // single texel).
        if (shape.length === 1) {
            shape = [2, shape[0]];
        }
    }
    // Rank-2 shapes are not squeezed so that they match the physical layout.
    if (shape.length !== 2) {
        shape = tf.util.squeezeShape(shape).newShape;
    }
    var size = tf.util.sizeFromShape(shape);
    var rank = shape.length;
    if (rank <= 1 && size <= limit) {
        return [1, size];
    }
    if (rank === 2 && shape[0] <= limit && shape[1] <= limit) {
        return shape;
    }
    if (rank === 3) {
        if (shape[0] * shape[1] <= limit && shape[2] <= limit) {
            return [shape[0] * shape[1], shape[2]];
        }
        if (shape[0] <= limit && shape[1] * shape[2] <= limit) {
            return [shape[0], shape[1] * shape[2]];
        }
    }
    if (rank === 4) {
        if (shape[0] * shape[1] * shape[2] <= limit && shape[3] <= limit) {
            return [shape[0] * shape[1] * shape[2], shape[3]];
        }
        if (shape[0] <= limit && shape[1] * shape[2] * shape[3] <= limit) {
            return [shape[0], shape[1] * shape[2] * shape[3]];
        }
    }
    if (!packed) {
        return tf.util.sizeToSquarishShape(size);
    }
    // For packed textures `size` counts channels. To squarify while keeping
    // the inner dimensions even, it is converted to a texel count first and
    // the squarish result is scaled back to channel units afterwards.
    var batchDim = getBatchDim(shape);
    var rows = 2;
    var cols = 2;
    if (shape.length) {
        var rowsCols = getRowsCols(shape);
        rows = rowsCols[0];
        cols = rowsCols[1];
    }
    size = batchDim * (rows / 2) * (cols / 2);
    return tf.util.sizeToSquarishShape(size).map(function (d) { return d * 2; });
}
|
|
/** Returns `true` when `n` leaves no remainder on division by 2. */
function isEven(n) {
    var remainder = n % 2;
    return remainder === 0;
}
|
|
/**
 * Decides whether a packed texture can be reshaped from `shape1` to `shape2`
 * without rearranging the data inside the texture, assuming 2x2 packing.
 * Only the last two dimensions of each shape are considered.
 */
function isReshapeFree(shape1, shape2) {
    var a = shape1.slice(-2);
    var b = shape2.slice(-2);
    if (tf.util.arraysEqual(a, b)) {
        return true;
    }
    // One of the shapes is a scalar.
    if (!a.length || !b.length) {
        return true;
    }
    var hasZeroDim = a[0] === 0 || a[1] === 0 || b[0] === 0 || b[1] === 0;
    if (hasZeroDim) {
        return true;
    }
    // One of the shapes is a vector.
    if (a.length !== b.length) {
        var aCols = a[a.length - 1];
        var bCols = b[b.length - 1];
        if (aCols === bCols) {
            return true;
        }
        var bothColsEven = isEven(aCols) && isEven(bCols);
        if (bothColsEven && (a[0] === 1 || b[0] === 1)) {
            return true;
        }
    }
    if (a[1] !== b[1]) {
        return false;
    }
    return isEven(a[0]) && isEven(b[0]);
}
|
|
// WebGL limits are cached here: the environment is reset between unit tests,
// and querying the WebGLContext for MAX_TEXTURE_SIZE each time would be
// wasteful.
var MAX_TEXTURE_SIZE, MAX_TEXTURES_IN_SHADER;
|
|
/**
 * Returns the context's `MAX_TEXTURE_SIZE`, querying it once and serving the
 * cached value until the cache is reset.
 */
function getWebGLMaxTextureSize(webGLVersion) {
    if (MAX_TEXTURE_SIZE != null) {
        return MAX_TEXTURE_SIZE;
    }
    var gl = getWebGLContext(webGLVersion);
    MAX_TEXTURE_SIZE = gl.getParameter(gl.MAX_TEXTURE_SIZE);
    return MAX_TEXTURE_SIZE;
}
|
|
/**
 * Clears the cached maximum texture size so the next lookup queries the
 * context again.
 */
function resetMaxTextureSize() {
    var cleared = null;
    MAX_TEXTURE_SIZE = cleared;
}
|
|
/**
 * Clears the cached texture-unit count so the next lookup queries the
 * context again.
 */
function resetMaxTexturesInShader() {
    var cleared = null;
    MAX_TEXTURES_IN_SHADER = cleared;
}
|
|
/**
 * Returns the number of textures a shader may use: the context's
 * `MAX_TEXTURE_IMAGE_UNITS` (queried once, then cached), capped at 16.
 */
function getMaxTexturesInShader(webGLVersion) {
    var isCached = MAX_TEXTURES_IN_SHADER != null;
    if (!isCached) {
        var gl = getWebGLContext(webGLVersion);
        MAX_TEXTURES_IN_SHADER = gl.getParameter(gl.MAX_TEXTURE_IMAGE_UNITS);
    }
    // The cap of 16 avoids a spurious runtime "memory exhausted" error.
    var cap = 16;
    return Math.min(cap, MAX_TEXTURES_IN_SHADER);
}
|
|
/**
 * Reports which disjoint-timer-query extension is usable: 2 when
 * `EXT_disjoint_timer_query_webgl2` is present on a version 2 context, 1 when
 * `EXT_disjoint_timer_query` is present, and 0 otherwise (including when
 * `webGLVersion` is 0).
 */
function getWebGLDisjointQueryTimerVersion(webGLVersion) {
    if (webGLVersion === 0) {
        return 0;
    }
    var gl = getWebGLContext(webGLVersion);
    // The WebGL 2 extension is probed first, whatever the version.
    var hasV2Timer = hasExtension(gl, 'EXT_disjoint_timer_query_webgl2');
    if (hasV2Timer && webGLVersion === 2) {
        return 2;
    }
    return hasExtension(gl, 'EXT_disjoint_timer_query') ? 1 : 0;
}
|
|
/**
 * Returns `true` when `gl.getExtension(extensionName)` yields something
 * other than `null` or `undefined`.
 */
function hasExtension(gl, extensionName) {
    return gl.getExtension(extensionName) != null;
}
|
|
/**
 * Returns `true` when a rendering context of the given version can be
 * obtained. A failure while obtaining it is logged and reported as `false`.
 */
function isWebGLVersionEnabled(webGLVersion) {
    var gl;
    try {
        gl = getWebGLContext(webGLVersion);
    }
    catch (e) {
        console.log('Error when getting WebGL context: ', e);
        return false;
    }
    return gl != null;
}
|
|
/**
 * Reports whether the context can render into a float texture. This needs
 * `OES_texture_float` on version 1 or `EXT_color_buffer_float` otherwise,
 * and the result of `createFloatTextureAndBindToFramebuffer` must be truthy.
 * Version 0 always yields `false`.
 */
function isCapableOfRenderingToFloatTexture(webGLVersion) {
    if (webGLVersion === 0) {
        return false;
    }
    var gl = getWebGLContext(webGLVersion);
    var requiredExtension = webGLVersion === 1 ? 'OES_texture_float' : 'EXT_color_buffer_float';
    if (!hasExtension(gl, requiredExtension)) {
        return false;
    }
    return createFloatTextureAndBindToFramebuffer(gl);
}
|
|
/**
 * Check if we can download values from a float/half-float texture.
 *
 * For performance reasons, binding a texture to a framebuffer is used as a
 * proxy for the ability to download float values later with readPixels. The
 * probe texture's parameters will not match those used by readPixels exactly,
 * but if no float texture can be bound to the framebuffer at all then float
 * values definitely cannot be read from it.
 *
 * WebGL 1 needs both OES_texture_float and WEBGL_color_buffer_float. Other
 * versions try EXT_color_buffer_float first and fall back to
 * EXT_color_buffer_half_float. Always false for webGLVersion 0.
 */
function isDownloadFloatTextureEnabled(webGLVersion) {
    if (webGLVersion === 0) {
        return false;
    }
    var context = getWebGLContext(webGLVersion);
    if (webGLVersion === 1) {
        var hasFloatSupport = hasExtension(context, 'OES_texture_float') &&
            hasExtension(context, 'WEBGL_color_buffer_float');
        return hasFloatSupport ?
            createFloatTextureAndBindToFramebuffer(context) :
            false;
    }
    if (hasExtension(context, 'EXT_color_buffer_float')) {
        return createFloatTextureAndBindToFramebuffer(context);
    }
    var halfFloatExtensionName = 'EXT_color_buffer_half_float';
    if (!hasExtension(context, halfFloatExtensionName)) {
        return false;
    }
    var halfFloatExtension = context.getExtension(halfFloatExtensionName);
    return createHalfFloatTextureAndBindToFramebuffer(context, halfFloatExtension);
}
|
|
/**
 * Probes float render-target support: allocates a 1x1 float texture, attaches
 * it to a temporary framebuffer and reports whether the framebuffer is
 * complete. The texture and framebuffer are unbound and deleted before
 * returning.
 */
function createFloatTextureAndBindToFramebuffer(gl) {
    var PROBE_SIZE = 1;
    var config = getTextureConfig(gl);
    var probeTexture = gl.createTexture();
    gl.bindTexture(gl.TEXTURE_2D, probeTexture);
    gl.texImage2D(gl.TEXTURE_2D, 0, config.internalFormatFloat, PROBE_SIZE, PROBE_SIZE, 0, config.textureFormatFloat, config.textureTypeFloat, null);
    var probeFramebuffer = gl.createFramebuffer();
    gl.bindFramebuffer(gl.FRAMEBUFFER, probeFramebuffer);
    gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, probeTexture, 0);
    var complete = gl.checkFramebufferStatus(gl.FRAMEBUFFER) === gl.FRAMEBUFFER_COMPLETE;
    // Restore default bindings and release the probe objects.
    gl.bindTexture(gl.TEXTURE_2D, null);
    gl.bindFramebuffer(gl.FRAMEBUFFER, null);
    gl.deleteTexture(probeTexture);
    gl.deleteFramebuffer(probeFramebuffer);
    return complete;
}
|
|
/**
 * Probes half-float render-target support: allocates a 1x1 half-float
 * texture, attaches it to a temporary framebuffer and reports whether the
 * framebuffer is complete. The texture and framebuffer are unbound and
 * deleted before returning.
 *
 * textureHalfFloatExtension is forwarded to getTextureConfig().
 */
function createHalfFloatTextureAndBindToFramebuffer(
// tslint:disable-next-line:no-any
gl, textureHalfFloatExtension) {
    var PROBE_SIZE = 1;
    var config = getTextureConfig(gl, textureHalfFloatExtension);
    var probeTexture = gl.createTexture();
    gl.bindTexture(gl.TEXTURE_2D, probeTexture);
    gl.texImage2D(gl.TEXTURE_2D, 0, config.internalFormatHalfFloat, PROBE_SIZE, PROBE_SIZE, 0, config.textureFormatFloat, config.textureTypeHalfFloat, null);
    var probeFramebuffer = gl.createFramebuffer();
    gl.bindFramebuffer(gl.FRAMEBUFFER, probeFramebuffer);
    gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, probeTexture, 0);
    var complete = gl.checkFramebufferStatus(gl.FRAMEBUFFER) === gl.FRAMEBUFFER_COMPLETE;
    // Restore default bindings and release the probe objects.
    gl.bindTexture(gl.TEXTURE_2D, null);
    gl.bindFramebuffer(gl.FRAMEBUFFER, null);
    gl.deleteTexture(probeTexture);
    gl.deleteFramebuffer(probeFramebuffer);
    return complete;
}
|
|
/**
 * Reports whether the fence API can be used: the version must be WebGL 2 and
 * the context must expose a fenceSync member. The context is not requested
 * for any other version.
 */
function isWebGLFenceEnabled(webGLVersion) {
    if (webGLVersion === 2) {
        // tslint:disable-next-line:no-any
        return getWebGLContext(webGLVersion).fenceSync != null;
    }
    return false;
}
|
|
/**
 * Asserts (via tf.util.assert) that none of the given tensors has dtype
 * 'complex64'.
 *
 * `tensor` may be a single tensor or an array of tensors; null/undefined
 * entries are skipped. `opName` is used as the prefix of the failure message.
 */
function assertNotComplex(tensor, opName) {
    var tensors = Array.isArray(tensor) ? tensor : [tensor];
    var describeFailure = function () {
        return opName + " does not support complex64 tensors " +
            'in the WebGL backend.';
    };
    for (var i = 0; i < tensors.length; i++) {
        var candidate = tensors[i];
        if (candidate == null) {
            continue;
        }
        tf.util.assert(candidate.dtype !== 'complex64', describeFailure);
    }
}
|
|
|
|
// Namespace object collecting the WebGL utility functions under one name.
// It has a null prototype, so it carries only the members assigned below.
var webgl_util = Object.create(null);
webgl_util.callAndCheck = callAndCheck;
webgl_util.canBeRepresented = canBeRepresented;
webgl_util.getWebGLErrorMessage = getWebGLErrorMessage;
webgl_util.getExtensionOrThrow = getExtensionOrThrow;
webgl_util.createVertexShader = createVertexShader;
webgl_util.createFragmentShader = createFragmentShader;
webgl_util.createProgram = createProgram;
webgl_util.linkProgram = linkProgram;
webgl_util.validateProgram = validateProgram;
webgl_util.createStaticVertexBuffer = createStaticVertexBuffer;
webgl_util.createStaticIndexBuffer = createStaticIndexBuffer;
webgl_util.getNumChannels = getNumChannels;
webgl_util.createTexture = createTexture;
webgl_util.validateTextureSize = validateTextureSize;
webgl_util.createFramebuffer = createFramebuffer;
webgl_util.bindVertexBufferToProgramAttribute = bindVertexBufferToProgramAttribute;
webgl_util.bindTextureUnit = bindTextureUnit;
webgl_util.unbindTextureUnit = unbindTextureUnit;
webgl_util.getProgramUniformLocationOrThrow = getProgramUniformLocationOrThrow;
webgl_util.getProgramUniformLocation = getProgramUniformLocation;
webgl_util.bindTextureToProgramUniformSampler = bindTextureToProgramUniformSampler;
webgl_util.bindCanvasToFramebuffer = bindCanvasToFramebuffer;
webgl_util.bindColorTextureToFramebuffer = bindColorTextureToFramebuffer;
webgl_util.unbindColorTextureFromFramebuffer = unbindColorTextureFromFramebuffer;
webgl_util.validateFramebuffer = validateFramebuffer;
webgl_util.getFramebufferErrorMessage = getFramebufferErrorMessage;
webgl_util.getBatchDim = getBatchDim;
webgl_util.getRowsCols = getRowsCols;
webgl_util.getShapeAs3D = getShapeAs3D;
webgl_util.getTextureShapeFromLogicalShape = getTextureShapeFromLogicalShape;
webgl_util.isReshapeFree = isReshapeFree;
webgl_util.getWebGLMaxTextureSize = getWebGLMaxTextureSize;
webgl_util.resetMaxTextureSize = resetMaxTextureSize;
webgl_util.resetMaxTexturesInShader = resetMaxTexturesInShader;
webgl_util.getMaxTexturesInShader = getMaxTexturesInShader;
webgl_util.getWebGLDisjointQueryTimerVersion = getWebGLDisjointQueryTimerVersion;
webgl_util.hasExtension = hasExtension;
webgl_util.isWebGLVersionEnabled = isWebGLVersionEnabled;
webgl_util.isCapableOfRenderingToFloatTexture = isCapableOfRenderingToFloatTexture;
webgl_util.isDownloadFloatTextureEnabled = isDownloadFloatTextureEnabled;
webgl_util.isWebGLFenceEnabled = isWebGLFenceEnabled;
webgl_util.assertNotComplex = assertNotComplex;
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
var ENV = tf.env();
/**
 * WebGL-specific flag registrations. Each flag is registered with a function
 * that computes its default value; some also register a validator that runs
 * when the flag is set explicitly.
 */
/** True if WebGL is supported. */
ENV.registerFlag('HAS_WEBGL', function () {
    return ENV.getNumber('WEBGL_VERSION') > 0;
});
/** 0: No WebGL, 1: WebGL 1.0, 2: WebGL 2.0. */
ENV.registerFlag('WEBGL_VERSION', function () {
    if (isWebGLVersionEnabled(2)) {
        return 2;
    }
    return isWebGLVersionEnabled(1) ? 1 : 0;
});
/** Whether to check for numerical representation problems. */
ENV.registerFlag('WEBGL_CHECK_NUMERICAL_PROBLEMS', function () {
    return false;
});
/** True only when WEBGL_VERSION is 2. */
ENV.registerFlag('WEBGL_BUFFER_SUPPORTED', function () {
    return ENV.get('WEBGL_VERSION') === 2;
});
/** Whether the WebGL backend will sometimes forward ops to the CPU. */
ENV.registerFlag('WEBGL_CPU_FORWARD', function () {
    return true;
});
/** Whether the WebGL backend will always use f16 textures for rendering. */
ENV.registerFlag('WEBGL_FORCE_F16_TEXTURES', function () {
    return false;
});
/** Whether to turn all packing related flags on. */
ENV.registerFlag('WEBGL_PACK', function () {
    return ENV.getBool('HAS_WEBGL');
});
/** Flags below all default to the value of WEBGL_PACK. */
[
    /** Whether we will pack the batchnormalization op. */
    'WEBGL_PACK_NORMALIZATION',
    /** Whether we will pack the clip op. */
    'WEBGL_PACK_CLIP',
    /** Whether we will pack the depthwise conv op. */
    'WEBGL_PACK_DEPTHWISECONV',
    /** Whether we will pack binary ops. */
    'WEBGL_PACK_BINARY_OPERATIONS',
    /** Whether we will pack unary ops. */
    'WEBGL_PACK_UNARY_OPERATIONS',
    /** Whether we will pack array ops. */
    'WEBGL_PACK_ARRAY_OPERATIONS',
    /** Whether we will pack image ops. */
    'WEBGL_PACK_IMAGE_OPERATIONS',
    /** Whether we will pack reduce ops. */
    'WEBGL_PACK_REDUCE',
    /** Whether packed WebGL kernels lazily unpack their outputs. */
    'WEBGL_LAZILY_UNPACK',
    /** Whether we will use the im2col algorithm to speed up convolutions. */
    'WEBGL_CONV_IM2COL'
].forEach(function (packFlagName) {
    ENV.registerFlag(packFlagName, function () {
        return ENV.getBool('WEBGL_PACK');
    });
});
/** The maximum texture dimension. */
ENV.registerFlag('WEBGL_MAX_TEXTURE_SIZE', function () {
    return getWebGLMaxTextureSize(ENV.getNumber('WEBGL_VERSION'));
});
/** The maximum number of textures that may be used in one shader. */
ENV.registerFlag('WEBGL_MAX_TEXTURES_IN_SHADER', function () {
    return getMaxTexturesInShader(ENV.getNumber('WEBGL_VERSION'));
});
/**
 * The disjoint_query_timer extension version.
 * 0: disabled, 1: EXT_disjoint_timer_query, 2:
 * EXT_disjoint_timer_query_webgl2.
 * In Firefox with WebGL 2.0, EXT_disjoint_timer_query_webgl2 is not
 * available, so we must use the WebGL 1.0 extension.
 */
ENV.registerFlag('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION', function () {
    var version = ENV.getNumber('WEBGL_VERSION');
    return version === 0 ? 0 : getWebGLDisjointQueryTimerVersion(version);
});
/**
 * Whether the timer object from the disjoint_query_timer extension gives
 * timing information that is reliable.
 */
ENV.registerFlag('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE', function () {
    if (ENV.getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') > 0) {
        return !tf.device_util.isMobile();
    }
    return false;
});
/**
 * Whether the device is physically capable of rendering to float32 textures.
 */
ENV.registerFlag('WEBGL_RENDER_FLOAT32_CAPABLE', function () {
    return isCapableOfRenderingToFloatTexture(ENV.getNumber('WEBGL_VERSION'));
});
/**
 * Whether rendering to float32 textures is enabled. If disabled, renders to
 * float16 textures.
 */
ENV.registerFlag('WEBGL_RENDER_FLOAT32_ENABLED', function () {
    if (ENV.getBool('WEBGL_FORCE_F16_TEXTURES')) {
        return false;
    }
    return ENV.getBool('WEBGL_RENDER_FLOAT32_CAPABLE');
});
/**
 * Whether downloading float textures is enabled (16 or 32 bit). If disabled,
 * uses IEEE 754 encoding of the float32 values to 4 uint8 when downloading.
 */
ENV.registerFlag('WEBGL_DOWNLOAD_FLOAT_ENABLED', function () {
    return isDownloadFloatTextureEnabled(ENV.getNumber('WEBGL_VERSION'));
});
/** Whether the fence API is available. */
ENV.registerFlag('WEBGL_FENCE_API_ENABLED', function () {
    return isWebGLFenceEnabled(ENV.getNumber('WEBGL_VERSION'));
});
/**
 * Tensors with size less than or equal to this will be uploaded as uniforms,
 * not textures.
 */
ENV.registerFlag('WEBGL_SIZE_UPLOAD_UNIFORM', function () {
    // Use uniform uploads only when 32bit floats are supported. In 16bit
    // environments there are problems with comparing a 16bit texture value
    // with a 32bit uniform value.
    if (ENV.getBool('WEBGL_RENDER_FLOAT32_ENABLED')) {
        return 4;
    }
    return 0;
});
/**
 * If the total number of bytes allocated on the GPU is greater than this
 * number, we will aggressively delete textures upon disposal with
 * gl.deleteMatrixTexture, rather than making them available for reuse.
 *
 * Default value -1 indicates that we will never aggressively delete textures.
 */
ENV.registerFlag('WEBGL_DELETE_TEXTURE_THRESHOLD', function () {
    return -1;
}, function (threshold) {
    var isInvalid = threshold !== -1 && threshold < 0;
    if (isInvalid) {
        throw new Error("WEBGL_DELETE_TEXTURE_THRESHOLD must be -1 (indicating never " +
            ("delete) or at least 0, but got " + threshold + "."));
    }
});
/**
 * Trigger a manual GL command flush if the threshold of time has passed since
 * the previous kernel execution. This can be useful for Android devices,
 * where GL command flushes are delayed until the end of the JavaScript task.
 * This value is measured in milliseconds. Typically you want to set this
 * value close to 1.
 *
 * Default value is 1 for mobile Chrome and -1 otherwise. -1 indicates that we
 * will not enforce a manual flush and depend on the system default flush
 * schedule.
 */
ENV.registerFlag('WEBGL_FLUSH_THRESHOLD', function () {
    if (tf.device_util.isMobile() && ENV.getBool('IS_CHROME')) {
        return 1;
    }
    return -1;
}, function (threshold) {
    var isInvalid = threshold !== -1 && threshold < 0;
    if (isInvalid) {
        throw new Error("WEBGL_FLUSH_THRESHOLD must be -1 (indicating never " +
            ("manual flush) or at least 0, but got " + threshold + "."));
    }
});
/**
 * Threshold for input tensor size that determines whether WebGL backend will
 * delegate computation to CPU.
 *
 * Default value is 128.
 */
ENV.registerFlag('CPU_HANDOFF_SIZE_THRESHOLD', function () {
    return 128;
});
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Returns the GLSL keywords and helper snippets that differ between shader
 * dialects, chosen by the WEBGL_VERSION flag: one set when the version is 2,
 * another set for every other value.
 *
 * The result has the keys version, attribute, varyingVs, varyingFs,
 * texture2D, output, defineOutput, defineSpecialNaN, defineSpecialInf and
 * defineRound.
 */
function getGlslDifferences() {
    var isWebGL2 = tf.env().getNumber('WEBGL_VERSION') === 2;
    if (isWebGL2) {
        return {
            version: '#version 300 es',
            attribute: 'in',
            varyingVs: 'out',
            varyingFs: 'in',
            texture2D: 'texture',
            output: 'outputColor',
            defineOutput: 'out vec4 outputColor;',
            // Use custom isnan definition to work across differences between
            // implementations on various platforms. While this should happen in
            // ANGLE we still see differences between android and windows (on
            // chrome) when using isnan directly.
            defineSpecialNaN: "\n bool isnan_custom(float val) {\n return (val > 0.0 || val < 0.0) ? false : val != 0.0;\n }\n\n bvec4 isnan_custom(vec4 val) {\n return bvec4(isnan_custom(val.x),\n isnan_custom(val.y), isnan_custom(val.z), isnan_custom(val.w));\n }\n\n #define isnan(value) isnan_custom(value)\n ",
            // In webgl 2 we do not need to specify a custom isinf so there is
            // no need for a special INFINITY constant.
            defineSpecialInf: "",
            defineRound: "\n #define round(value) newRound(value)\n int newRound(float value) {\n return int(floor(value + 0.5));\n }\n\n ivec4 newRound(vec4 value) {\n return ivec4(floor(value + vec4(0.5)));\n }\n "
        };
    }
    return {
        version: '',
        attribute: 'attribute',
        varyingVs: 'varying',
        varyingFs: 'varying',
        texture2D: 'texture2D',
        output: 'gl_FragColor',
        defineOutput: '',
        // WebGL1 has no built in isnan so we define one here.
        defineSpecialNaN: "\n #define isnan(value) isnan_custom(value)\n bool isnan_custom(float val) {\n return (val > 0. || val < 1. || val == 0.) ? false : true;\n }\n bvec4 isnan_custom(vec4 val) {\n return bvec4(isnan(val.x), isnan(val.y), isnan(val.z), isnan(val.w));\n }\n ",
        defineSpecialInf: "\n uniform float INFINITY;\n\n bool isinf(float val) {\n return abs(val) == INFINITY;\n }\n bvec4 isinf(vec4 val) {\n return equal(abs(val), vec4(INFINITY));\n }\n ",
        defineRound: "\n int round(float value) {\n return int(floor(value + 0.5));\n }\n\n ivec4 round(vec4 value) {\n return ivec4(floor(value + vec4(0.5)));\n }\n "
    };
}
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Produces GLSL code that derives logical coordinates from a flat
 * index. The code performs integer division with each stride and decrements
 * the index until the index equals the final dimension coordinate.
 *
 * @param {string[]} coords Names of the GLSL int variables to declare. One
 *     name is consumed per stride, plus one more for the final coordinate.
 * @param {number[]} shape Shape passed to tf.util.computeStrides to obtain
 *     the strides.
 * @param {string} [index='index'] Name of the GLSL int variable holding the
 *     flat index. The generated code decrements this variable in place.
 * @returns {string} The generated GLSL statements, concatenated without
 *     separators; an empty string when there are no strides.
 */
function getLogicalCoordinatesFromFlatIndex(coords, shape, index) {
    if (index === void 0) { index = 'index'; }
    var strides = tf.util.computeStrides(shape);
    var lastStrideIdx = strides.length - 1;
    return strides
        .map(function (stride, i) {
            var quotient = "int " + coords[i] + " = " + index + " / " + stride;
            // The running remainder must be kept in the caller-named index
            // variable, not in a variable literally called "index".
            var remainder = i === lastStrideIdx ?
                "int " + coords[i + 1] + " = " + index + " - " + coords[i] + " * " + stride :
                index + " -= " + coords[i] + " * " + stride;
            return quotient + "; " + remainder + ";";
        })
        .join('');
}
|
|
/**
 * Produces GLSL that computes the flat index from 3D coordinates.
 *
 * The emitted getFlatIndex(ivec3) function multiplies coords.x and coords.y
 * by the first two strides of `shape` and adds coords.z.
 */
function getFlatIndexFrom3D(shape) {
    var strides = tf.util.computeStrides(shape);
    var xStride = String(strides[0]);
    var yStride = String(strides[1]);
    return "\n int getFlatIndex(ivec3 coords) {\n return coords.x * " + xStride +
        " + coords.y * " + yStride +
        " + coords.z;\n }\n";
}
|
|
// GLSL source that defines encode_float(highp float) -> lowp vec4. The
// snippet returns four byte-valued channels divided by 255: NaN maps to
// vec4(255), magnitudes below FLOAT_MIN map to zero, and values beyond
// +/-FLOAT_MAX map to fixed patterns. Otherwise the mantissa is spread over
// c[0..2], the exponent (biased by 127) over c[2..3], and the sign is added
// to c[3]. The snippet calls isnan(), so that must be defined by the
// surrounding shader source.
var ENCODE_FLOAT_SNIPPET = "\n const float FLOAT_MAX = 1.70141184e38;\n const float FLOAT_MIN = 1.17549435e-38;\n\n lowp vec4 encode_float(highp float v) {\n if (isnan(v)) {\n return vec4(255, 255, 255, 255);\n }\n\n highp float av = abs(v);\n\n if(av < FLOAT_MIN) {\n return vec4(0.0, 0.0, 0.0, 0.0);\n } else if(v > FLOAT_MAX) {\n return vec4(0.0, 0.0, 128.0, 127.0) / 255.0;\n } else if(v < -FLOAT_MAX) {\n return vec4(0.0, 0.0, 128.0, 255.0) / 255.0;\n }\n\n highp vec4 c = vec4(0,0,0,0);\n\n highp float e = floor(log2(av));\n highp float m = exp2(fract(log2(av))) - 1.0;\n\n c[2] = floor(128.0 * m);\n m -= c[2] / 128.0;\n c[1] = floor(32768.0 * m);\n m -= c[1] / 32768.0;\n c[0] = floor(8388608.0 * m);\n\n highp float ebias = e + 127.0;\n c[3] = floor(ebias / 2.0);\n ebias -= c[3] * 2.0;\n c[2] += floor(ebias) * 128.0;\n\n c[3] += 128.0 * step(0.0, -v);\n\n return c / 255.0;\n }\n";
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Shader program description that reads an unpacked input 'A' and writes a
 * packed output using the DENSE packing scheme. Each output texel holds four
 * consecutive flat indices of `outputShape`.
 */
var DecodeMatrixProgram = /** @class */ (function () {
    function DecodeMatrixProgram(outputShape) {
        this.variableNames = ['A'];
        this.packedInputs = false;
        this.packedOutput = true;
        this.outPackingScheme = PackingScheme.DENSE;
        var denseTexShape = getDenseTexShape(outputShape);
        var texRows = denseTexShape[0];
        var texCols = denseTexShape[1];
        var glslOutput = getGlslDifferences().output;
        this.outputShape = outputShape;
        var coordsSnippet = getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd'], outputShape);
        // Helper that turns a flat index back into (r, c, d) coordinates.
        var coordsFn = "\n ivec3 outCoordsFromFlatIndex(int index) {\n " + coordsSnippet + "\n return ivec3(r, c, d);\n }\n\n";
        // Entry point: gather four consecutive values into one texel.
        var mainFn = " void main() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(" + texRows + ", " + texCols + "));\n int index = 4 * (resTexRC.x * " + texCols + " + resTexRC.y);\n\n vec4 result = vec4(0.);\n\n for (int i=0; i<4; i++) {\n int flatIndex = index + i;\n ivec3 rc = outCoordsFromFlatIndex(flatIndex);\n result[i] = getA(rc.x, rc.y, rc.z);\n }\n\n " + glslOutput + " = result;\n }\n ";
        this.userCode = coordsFn + mainFn;
    }
    return DecodeMatrixProgram;
}());
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Shader program description that reads a packed input 'A' and writes a
 * packed output using the DENSE packing scheme. Each output texel holds four
 * consecutive flat indices of `outputShape`.
 */
var DecodeMatrixPackedProgram = /** @class */ (function () {
    function DecodeMatrixPackedProgram(outputShape) {
        this.variableNames = ['A'];
        this.packedInputs = true;
        this.packedOutput = true;
        this.outPackingScheme = PackingScheme.DENSE;
        var denseTexShape = getDenseTexShape(outputShape);
        var texRows = denseTexShape[0];
        var texCols = denseTexShape[1];
        var glslOutput = getGlslDifferences().output;
        this.outputShape = outputShape;
        var coordsSnippet = getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd'], outputShape);
        // Helper that turns a flat index back into (r, c, d) coordinates.
        var coordsFn = "\n ivec3 outCoordsFromFlatIndex(int index) {\n " + coordsSnippet + "\n return ivec3(r, c, d);\n }\n\n";
        // Entry point: pick the matching channel of each packed input texel.
        var mainFn = " void main() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(" + texRows + ", " + texCols + "));\n int index = 4 * (resTexRC.x * " + texCols + " + resTexRC.y);\n\n vec4 result = vec4(0.);\n\n for (int i=0; i<4; i++) {\n int flatIndex = index + i;\n ivec3 rc = outCoordsFromFlatIndex(flatIndex);\n result[i] = getChannel(getA(rc.x, rc.y, rc.z), vec2(rc.y, rc.z));\n }\n\n " + glslOutput + " = result;\n }\n ";
        this.userCode = coordsFn + mainFn;
    }
    return DecodeMatrixPackedProgram;
}());
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Shader program description that reads input 'A' at the output coordinates
 * and writes encode_float() of that value. The output texture is marked for
 * download.
 */
var EncodeFloatProgram = /** @class */ (function () {
    function EncodeFloatProgram(outputShape) {
        this.variableNames = ['A'];
        this.outTexUsage = TextureUsage.DOWNLOAD;
        var glslOutput = getGlslDifferences().output;
        this.outputShape = outputShape;
        var mainFn = " void main() {\n float x = getAAtOutCoords();\n " + glslOutput + " = encode_float(x);\n }\n ";
        this.userCode = "\n " + ENCODE_FLOAT_SNIPPET + "\n\n" + mainFn;
    }
    return EncodeFloatProgram;
}());
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Shader program description that reads a packed input 'A', selects the
 * channel for the current output coordinates and writes encode_float() of
 * that value to an unpacked output. The output texture is marked for
 * download.
 */
var EncodeFloatPackedProgram = /** @class */ (function () {
    function EncodeFloatPackedProgram(outputShape) {
        this.variableNames = ['A'];
        this.packedInputs = true;
        this.packedOutput = false;
        this.outTexUsage = TextureUsage.DOWNLOAD;
        var glslOutput = getGlslDifferences().output;
        this.outputShape = outputShape;
        var mainFn = " void main() {\n ivec3 coords = getOutputCoords();\n float x = getChannel(getAAtOutCoords(), vec2(coords.y, coords.z));\n " + glslOutput + " = encode_float(x);\n }\n ";
        this.userCode = "\n " + ENCODE_FLOAT_SNIPPET + "\n\n" + mainFn;
    }
    return EncodeFloatPackedProgram;
}());
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Shader program description that samples texture 'A' (texShape is
 * [height, width], four values per texel) and writes one value per output
 * element into the first channel. When inputIsUnsignedByte is truthy the
 * sampled value is scaled by 255 and rounded.
 */
var EncodeMatrixProgram = /** @class */ (function () {
    function EncodeMatrixProgram(outputShape, texShape, inputIsUnsignedByte) {
        this.variableNames = ['A'];
        var glsl = getGlslDifferences();
        var texHeight = texShape[0];
        var texWidth = texShape[1];
        this.outputShape = outputShape;
        var valueExpr = inputIsUnsignedByte ?
            "floor(result * 255. + 0.5)" :
            "result";
        var flatIndexFn = getFlatIndexFrom3D(outputShape);
        // Locate the texel that holds the flat index and sample it.
        var sampleCode = " void main() {\n ivec3 coords = getOutputCoords();\n\n int flatIndex = getFlatIndex(coords);\n int offset = imod(flatIndex, 4);\n\n flatIndex = idiv(flatIndex, 4, 1.);\n\n int r = flatIndex / " + texWidth + ";\n int c = imod(flatIndex, " + texWidth + ");\n vec2 uv = (vec2(c, r) + halfCR) / vec2(" + texWidth + ".0, " + texHeight + ".0);\n vec4 values = " + glsl.texture2D + "(A, uv);\n\n";
        // Pick the channel selected by the offset and emit it.
        var selectCode = " float result;\n\n if(offset == 0) {\n result = values[0];\n } else if(offset == 1) {\n result = values[1];\n } else if(offset == 2) {\n result = values[2];\n } else {\n result = values[3];\n }\n\n " + glsl.output + " = vec4(" + valueExpr + ", 0., 0., 0.);\n }\n ";
        this.userCode = "\n " + flatIndexFn + "\n\n" + sampleCode + selectCode;
    }
    return EncodeMatrixProgram;
}());
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/*
|
|
This is how the shader encodes a tensor with shape = [2, 3, 5]
|
|
(indices are [batch, row, col]).
|
|
|
|
000|001 002|003 004|xxx 020|021 022|023 024|xxx
|
|
------- ------- ------- ------- ------- -------
|
|
010|011 012|013 014|xxx xxx|xxx xxx|xxx xxx|xxx
|
|
|
|
100|101 102|103 104|xxx 120|121 122|123 124|xxx
|
|
------- ------- ------- ------- ------- -------
|
|
110|111 112|113 114|xxx xxx|xxx xxx|xxx xxx|xxx
|
|
|
|
Single texels contain only values from the same batch, and from adjacent rows
|
|
and columns.
|
|
*/
|
|
/**
 * Shader program description that samples an unpacked texture 'A' (texShape
 * is [height, width], four values per texel) and writes a packed output.
 * Each output texel holds a 2x2 patch (row/col offsets 0..1) of the last two
 * dimensions of `outputShape`; positions outside the shape keep the initial
 * value 0. When inputIsUnsignedByte is truthy the result is scaled by 255
 * and rounded.
 */
var EncodeMatrixPackedProgram = /** @class */ (function () {
    function EncodeMatrixPackedProgram(outputShape, texShape, inputIsUnsignedByte) {
        this.variableNames = ['A'];
        this.packedInputs = false;
        this.packedOutput = true;
        var glsl = getGlslDifferences();
        var texHeight = texShape[0];
        var texWidth = texShape[1];
        this.outputShape = outputShape;
        var valueExpr = inputIsUnsignedByte ?
            'floor(result * 255. + 0.5)' :
            'result';
        // One code section per output channel: channel = row * 2 + col.
        var channelSections = [];
        for (var channel = 0; channel < 4; channel++) {
            var row = Math.floor(channel / 2);
            var col = channel % 2;
            channelSections.push("\n localCoords = coords;\n if(localCoords[2] + " + col + " < " + outputShape[2] + ") {\n localCoords[2] += " + col + ";\n if(localCoords[1] + " + row + " < " + outputShape[1] + ") {\n localCoords[1] += " + row + ";\n\n flatIndex = getFlatIndex(localCoords);\n offset = imod(flatIndex, 4);\n\n flatIndex = idiv(flatIndex, 4, 1.);\n\n r = flatIndex / " + texWidth + ";\n c = imod(flatIndex, " + texWidth + ");\n uv = (vec2(c, r) + halfCR) / vec2(" + texWidth + ".0, " + texHeight + ".0);\n values = " + glsl.texture2D + "(A, uv);\n\n if(offset == 0) {\n result[" + channel + "] = values[0];\n } else if(offset == 1) {\n result[" + channel + "] = values[1];\n } else if(offset == 2) {\n result[" + channel + "] = values[2];\n } else {\n result[" + channel + "] = values[3];\n }\n }\n }\n ");
        }
        var mainLoop = channelSections.join('');
        var flatIndexFn = getFlatIndexFrom3D(outputShape);
        var declarations = " void main() {\n ivec3 coords = getOutputCoords();\n\n vec4 result = vec4(0.);\n int flatIndex, r, c, offset;\n ivec3 localCoords;\n vec2 uv;\n vec4 values;\n\n ";
        this.userCode = "\n " + flatIndexFn + "\n\n" + declarations + mainLoop + "\n\n " + glsl.output + " = " + valueExpr + ";\n }\n ";
    }
    return EncodeMatrixPackedProgram;
}());
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2017 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Builds and compiles the pass-through vertex shader used by GPGPU programs.
 * The generated shader writes `clipSpacePos` to `gl_Position` and copies the
 * `uv` attribute into the `resultUV` varying.
 *
 * @param gl WebGL context handed to `createVertexShader`.
 * @returns The value returned by `createVertexShader` for the generated
 *     source.
 */
function createVertexShader$1(gl) {
    // Keyword spellings (version line, attribute/varying qualifiers) come
    // from the GLSL dialect descriptor.
    var dialect = getGlslDifferences();
    var header = dialect.version + "\n precision highp float;\n ";
    var declarations = dialect.attribute + " vec3 clipSpacePos;\n " +
        dialect.attribute + " vec2 uv;\n " +
        dialect.varyingVs + " vec2 resultUV;\n\n ";
    var entryPoint = "void main() {\n gl_Position = vec4(clipSpacePos, 1);\n resultUV = uv;\n }";
    return createVertexShader(gl, header + declarations + entryPoint);
}
|
|
/**
 * Creates the static vertex buffer for the quad drawn by GPGPU programs.
 * Each vertex is five floats laid out as [x y z u v].
 *
 * @param gl WebGL context handed to `createStaticVertexBuffer`.
 * @returns The value returned by `createStaticVertexBuffer`.
 */
function createVertexBuffer(gl) {
    var corners = [
        [-1, 1, 0, 0, 1], // upper-left
        [-1, -1, 0, 0, 0], // lower-left
        [1, 1, 0, 1, 1], // upper-right
        [1, -1, 0, 1, 0] // lower-right
    ];
    // Flatten the per-corner rows into one interleaved array.
    var interleaved = [].concat.apply([], corners);
    return createStaticVertexBuffer(gl, new Float32Array(interleaved));
}
|
|
/**
 * Creates the static index buffer describing the two triangles of the quad.
 *
 * @param gl WebGL context handed to `createStaticIndexBuffer`.
 * @returns The value returned by `createStaticIndexBuffer`.
 */
function createIndexBuffer(gl) {
    // OpenGL (and WebGL) have "CCW == front" winding.
    var firstTriangle = [0, 1, 2];
    var secondTriangle = [2, 1, 3];
    var indices = new Uint16Array(firstTriangle.concat(secondTriangle));
    return createStaticIndexBuffer(gl, indices);
}
|
|
/**
 * Creates a 2D texture, configures clamp-to-edge wrapping and nearest
 * filtering, and allocates (without initial data) its storage.
 *
 * @param gl WebGL context.
 * @param width Texture width passed to `texImage2D`.
 * @param height Texture height passed to `texImage2D`.
 * @param internalFormat Internal format passed to `texImage2D`.
 * @param textureFormat Pixel format passed to `texImage2D`.
 * @param textureType Pixel data type passed to `texImage2D`.
 * @returns The texture produced by `createTexture`, left unbound.
 */
function createAndConfigureTexture(gl, width, height, internalFormat, textureFormat, textureType) {
    validateTextureSize(width, height);
    var texture = createTexture(gl);
    var target = gl.TEXTURE_2D;
    // [parameter name, value] pairs, applied in this order.
    var samplerSettings = [
        [gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE],
        [gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE],
        [gl.TEXTURE_MIN_FILTER, gl.NEAREST],
        [gl.TEXTURE_MAG_FILTER, gl.NEAREST]
    ];
    callAndCheck(gl, function () { return gl.bindTexture(target, texture); });
    samplerSettings.forEach(function (setting) {
        callAndCheck(gl, function () { return gl.texParameteri(target, setting[0], setting[1]); });
    });
    // A null data pointer allocates storage without uploading pixels.
    callAndCheck(gl, function () {
        return gl.texImage2D(target, 0, internalFormat, width, height, 0, textureFormat, textureType, null);
    });
    callAndCheck(gl, function () { return gl.bindTexture(gl.TEXTURE_2D, null); });
    return texture;
}
|
|
/**
 * Reads the internal format used for float32 matrix textures.
 *
 * @param textureConfig Texture configuration object.
 * @returns `textureConfig.internalFormatFloat`.
 */
function getInternalFormatForFloat32MatrixTexture(textureConfig) {
    var internalFormat = textureConfig.internalFormatFloat;
    return internalFormat;
}
|
|
/**
 * Creates an unpacked float32 matrix texture.
 *
 * @param gl WebGL context.
 * @param rows Logical matrix rows.
 * @param columns Logical matrix columns.
 * @param textureConfig Texture configuration supplying the formats.
 * @returns The texture returned by `createAndConfigureTexture`.
 */
function createFloat32MatrixTexture(gl, rows, columns, textureConfig) {
    // The helper returns the physical size as [width, height].
    var texSize = getUnpackedMatrixTextureShapeWidthHeight(rows, columns);
    var internalFormat = getInternalFormatForFloat32MatrixTexture(textureConfig);
    return createAndConfigureTexture(gl, texSize[0], texSize[1], internalFormat, textureConfig.textureFormatFloat, gl.FLOAT);
}
|
|
/**
 * Reads the internal format used for float16 matrix textures.
 *
 * @param textureConfig Texture configuration object.
 * @returns `textureConfig.internalFormatHalfFloat`.
 */
function getInternalFormatForFloat16MatrixTexture(textureConfig) {
    var internalFormat = textureConfig.internalFormatHalfFloat;
    return internalFormat;
}
|
|
/**
 * Creates an unpacked float16 matrix texture.
 *
 * @param gl WebGL context.
 * @param rows Logical matrix rows.
 * @param columns Logical matrix columns.
 * @param textureConfig Texture configuration supplying the formats and the
 *     half-float texel type.
 * @returns The texture returned by `createAndConfigureTexture`.
 */
function createFloat16MatrixTexture(gl, rows, columns, textureConfig) {
    // The helper returns the physical size as [width, height].
    var texSize = getUnpackedMatrixTextureShapeWidthHeight(rows, columns);
    var internalFormat = getInternalFormatForFloat16MatrixTexture(textureConfig);
    return createAndConfigureTexture(gl, texSize[0], texSize[1], internalFormat, textureConfig.textureFormatFloat, textureConfig.textureTypeHalfFloat);
}
|
|
/**
 * Reads the internal format used for unsigned-byte matrix textures.
 *
 * @param textureConfig Texture configuration object.
 * @returns `textureConfig.downloadTextureFormat`.
 */
function getInternalFormatForUnsignedBytesMatrixTexture(textureConfig) {
    var internalFormat = textureConfig.downloadTextureFormat;
    return internalFormat;
}
|
|
/**
 * Creates an unpacked RGBA / unsigned-byte matrix texture.
 *
 * @param gl WebGL context.
 * @param rows Logical matrix rows.
 * @param columns Logical matrix columns.
 * @param textureConfig Texture configuration supplying the internal format.
 * @returns The texture returned by `createAndConfigureTexture`.
 */
function createUnsignedBytesMatrixTexture(gl, rows, columns, textureConfig) {
    // The helper returns the physical size as [width, height].
    var texSize = getUnpackedMatrixTextureShapeWidthHeight(rows, columns);
    var internalFormat = getInternalFormatForUnsignedBytesMatrixTexture(textureConfig);
    return createAndConfigureTexture(gl, texSize[0], texSize[1], internalFormat, gl.RGBA, gl.UNSIGNED_BYTE);
}
|
|
/**
 * Reads the internal format used for packed float32 matrix textures.
 *
 * @param textureConfig Texture configuration object.
 * @returns `textureConfig.internalFormatPackedFloat`.
 */
function getInternalFormatForPackedMatrixTexture(textureConfig) {
    var internalFormat = textureConfig.internalFormatPackedFloat;
    return internalFormat;
}
|
|
/**
 * Creates a packed RGBA / float32 matrix texture.
 *
 * @param gl WebGL context.
 * @param rows Logical matrix rows.
 * @param columns Logical matrix columns.
 * @param textureConfig Texture configuration supplying the internal format.
 * @returns The texture returned by `createAndConfigureTexture`.
 */
function createPackedMatrixTexture(gl, rows, columns, textureConfig) {
    // The helper returns the packed physical size as [width, height].
    var texSize = getPackedMatrixTextureShapeWidthHeight(rows, columns);
    var internalFormat = getInternalFormatForPackedMatrixTexture(textureConfig);
    return createAndConfigureTexture(gl, texSize[0], texSize[1], internalFormat, gl.RGBA, gl.FLOAT);
}
|
|
/**
 * Reads the internal format used for packed float16 matrix textures.
 *
 * @param textureConfig Texture configuration object.
 * @returns `textureConfig.internalFormatPackedHalfFloat`.
 */
function getInternalFormatForFloat16PackedMatrixTexture(textureConfig) {
    var internalFormat = textureConfig.internalFormatPackedHalfFloat;
    return internalFormat;
}
|
|
/**
 * Creates a packed RGBA / float16 matrix texture.
 *
 * @param gl WebGL context.
 * @param rows Logical matrix rows.
 * @param columns Logical matrix columns.
 * @param textureConfig Texture configuration supplying the internal format
 *     and the half-float texel type.
 * @returns The texture returned by `createAndConfigureTexture`.
 */
function createFloat16PackedMatrixTexture(gl, rows, columns, textureConfig) {
    // The helper returns the packed physical size as [width, height].
    var texSize = getPackedMatrixTextureShapeWidthHeight(rows, columns);
    var internalFormat = getInternalFormatForFloat16PackedMatrixTexture(textureConfig);
    return createAndConfigureTexture(gl, texSize[0], texSize[1], internalFormat, gl.RGBA, textureConfig.textureTypeHalfFloat);
}
|
|
/**
 * Binds the quad vertex buffer and points the program's `clipSpacePos` and
 * `uv` attributes at it.
 *
 * @param gl WebGL context.
 * @param program Program whose attributes are bound.
 * @param vertexBuffer Buffer holding interleaved [x y z u v] vertices.
 * @returns The falsy result of the `clipSpacePos` binding if it failed,
 *     otherwise the result of the `uv` binding.
 */
function bindVertexProgramAttributeStreams(gl, program, vertexBuffer) {
    var BYTES_PER_FLOAT = 4;
    var POSITION_FLOATS = 3; // x y z
    var UV_FLOATS = 2; // u v
    var positionOffset = 0; // x is the first buffer element
    var uvOffset = POSITION_FLOATS * BYTES_PER_FLOAT; // uv comes after [x y z]
    var stride = (POSITION_FLOATS * BYTES_PER_FLOAT) + (UV_FLOATS * BYTES_PER_FLOAT);
    callAndCheck(gl, function () { return gl.bindBuffer(gl.ARRAY_BUFFER, vertexBuffer); });
    var positionBound = bindVertexBufferToProgramAttribute(gl, program, 'clipSpacePos', vertexBuffer, POSITION_FLOATS, stride, positionOffset);
    if (!positionBound) {
        // Skip the uv binding when the position binding did not succeed.
        return positionBound;
    }
    return bindVertexBufferToProgramAttribute(gl, program, 'uv', vertexBuffer, UV_FLOATS, stride, uvOffset);
}
|
|
/**
 * Uploads `data` into `texture` as RGBA texels. The data is first copied
 * into a zero-initialised staging array of `width * height * 4` values.
 *
 * @param gl WebGL context.
 * @param texture Destination texture.
 * @param width Texture width in texels.
 * @param height Texture height in texels.
 * @param data Values to upload; a `Uint8Array` is uploaded as unsigned
 *     bytes, anything else is staged in a `Float32Array`.
 * @param textureConfig Texture configuration; only read on the float path
 *     (for `internalFormatPackedFloat`).
 */
function uploadDenseMatrixToTexture(gl, texture, width, height, data, textureConfig) {
    callAndCheck(gl, function () { return gl.bindTexture(gl.TEXTURE_2D, texture); });
    var isByteData = data instanceof Uint8Array;
    var valueCount = width * height * 4;
    var staging = isByteData ? new Uint8Array(valueCount) : new Float32Array(valueCount);
    var texelDataType = isByteData ? gl.UNSIGNED_BYTE : gl.FLOAT;
    var internalFormat = isByteData ? gl.RGBA : textureConfig.internalFormatPackedFloat;
    staging.set(data);
    callAndCheck(gl, function () {
        return gl.texImage2D(gl.TEXTURE_2D, 0, internalFormat, width, height, 0, gl.RGBA, texelDataType, staging);
    });
    callAndCheck(gl, function () { return gl.bindTexture(gl.TEXTURE_2D, null); });
}
|
|
/**
 * Uploads pixel data into `texture` as RGBA / unsigned bytes.
 *
 * @param gl WebGL context.
 * @param texture Destination texture.
 * @param pixels Either an object whose `data` is a `Uint8Array` (uploaded
 *     with its `width` and `height`), or any other source that is passed to
 *     `texImage2D` directly.
 */
function uploadPixelDataToTexture(gl, texture, pixels) {
    callAndCheck(gl, function () { return gl.bindTexture(gl.TEXTURE_2D, texture); });
    var hasRawBytes = pixels.data instanceof Uint8Array;
    var upload = hasRawBytes ?
        function () {
            return gl.texImage2D(gl.TEXTURE_2D, 0, gl.RGBA, pixels.width, pixels.height, 0, gl.RGBA, gl.UNSIGNED_BYTE, pixels.data);
        } :
        function () {
            return gl.texImage2D(gl.TEXTURE_2D, 0, gl.RGBA, gl.RGBA, gl.UNSIGNED_BYTE, pixels);
        };
    callAndCheck(gl, upload);
    callAndCheck(gl, function () { return gl.bindTexture(gl.TEXTURE_2D, null); });
}
|
|
/**
 * Creates a pixel-pack buffer and enqueues a `readPixels` into it.
 *
 * @param gl2 WebGL context providing the pixel-pack-buffer API.
 * @param rows Number of texel rows to read.
 * @param columns Number of texel columns to read.
 * @param textureConfig Unused by this function.
 * @returns The newly created buffer, left unbound.
 */
function createBufferFromOutputTexture(gl2, rows, columns, textureConfig) {
    var pixelBuffer = gl2.createBuffer();
    callAndCheck(gl2, function () { return gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, pixelBuffer); });
    // Size the buffer for RGBA float texels: 4 bytes per float, 4 floats
    // per texel.
    var BYTES_PER_TEXEL = 4 * 4;
    var byteLength = BYTES_PER_TEXEL * rows * columns;
    callAndCheck(gl2, function () { return gl2.bufferData(gl2.PIXEL_PACK_BUFFER, byteLength, gl2.STREAM_READ); });
    // With a pixel-pack buffer bound, the final argument is an offset into
    // that buffer, so this enqueues a copy into the buffer.
    callAndCheck(gl2, function () { return gl2.readPixels(0, 0, columns, rows, gl2.RGBA, gl2.FLOAT, 0); });
    callAndCheck(gl2, function () { return gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, null); });
    return pixelBuffer;
}
|
|
/**
 * Copies the contents of a pixel-pack buffer into a new `Float32Array`.
 *
 * @param gl WebGL context providing `getBufferSubData`.
 * @param buffer Buffer to read from.
 * @param size Number of floats to allocate for the result.
 * @returns A `Float32Array` of length `size` filled by `getBufferSubData`.
 */
function downloadFloat32MatrixFromBuffer(gl, buffer, size) {
    var values = new Float32Array(size);
    var packTarget = gl.PIXEL_PACK_BUFFER;
    gl.bindBuffer(packTarget, buffer);
    gl.getBufferSubData(packTarget, 0, values);
    gl.bindBuffer(packTarget, null);
    return values;
}
|
|
/**
 * Reads back a texture whose texels hold floats encoded as 4 bytes each and
 * reinterprets the bytes as float32 values.
 *
 * @param gl WebGL context.
 * @param rows Logical matrix rows.
 * @param columns Logical matrix columns.
 * @param textureConfig Texture configuration supplying
 *     `downloadTextureFormat`.
 * @returns A `Float32Array` view over the downloaded bytes.
 */
function downloadByteEncodedFloatMatrixFromOutputTexture(gl, rows, columns, textureConfig) {
    var texSize = getUnpackedMatrixTextureShapeWidthHeight(rows, columns);
    var texWidth = texSize[0];
    var texHeight = texSize[1];
    var CHANNELS_PER_TEXEL = 4;
    var byteCount = getUnpackedArraySizeFromMatrixSize(rows * columns, CHANNELS_PER_TEXEL);
    var rawBytes = new Uint8Array(byteCount);
    callAndCheck(gl, function () {
        return gl.readPixels(0, 0, texWidth, texHeight, textureConfig.downloadTextureFormat, gl.UNSIGNED_BYTE, rawBytes);
    });
    // Viewing the same buffer as Float32Array lets the browser do the
    // IEEE 754 decoding of the 4 bytes backing each 32-bit float.
    return new Float32Array(rawBytes.buffer);
}
|
|
/**
 * Copies a packed matrix out of a pixel-pack buffer into a new
 * `Float32Array`.
 *
 * @param gl WebGL context providing `getBufferSubData`.
 * @param buffer Buffer to read from.
 * @param batch Unused by this function.
 * @param rows Unused by this function.
 * @param cols Unused by this function.
 * @param physicalRows Physical texture rows; determines the result size.
 * @param physicalCols Physical texture columns; determines the result size.
 * @param textureConfig Unused by this function.
 * @returns A `Float32Array` sized by `getPackedRGBAArraySizeFromMatrixShape`
 *     and filled by `getBufferSubData`.
 */
function downloadPackedMatrixFromBuffer(gl, buffer, batch, rows, cols, physicalRows, physicalCols, textureConfig) {
    var valueCount = getPackedRGBAArraySizeFromMatrixShape(physicalRows, physicalCols);
    var values = new Float32Array(valueCount);
    var packTarget = gl.PIXEL_PACK_BUFFER;
    gl.bindBuffer(packTarget, buffer);
    gl.getBufferSubData(packTarget, 0, values);
    gl.bindBuffer(packTarget, null);
    return values;
}
|
|
/**
 * Reads RGBA float texels back into a new `Float32Array` via `readPixels`.
 *
 * @param gl WebGL context.
 * @param physicalRows Number of texel rows to read.
 * @param physicalCols Number of texel columns to read.
 * @returns A `Float32Array` of `physicalRows * physicalCols * 4` values.
 */
function downloadMatrixFromPackedOutputTexture(gl, physicalRows, physicalCols) {
    var valueCount = physicalRows * physicalCols * 4;
    var texels = new Float32Array(valueCount);
    callAndCheck(gl, function () {
        return gl.readPixels(0, 0, physicalCols, physicalRows, gl.RGBA, gl.FLOAT, texels);
    });
    return texels;
}
|
|
|
|
// Namespace object bundling the GPGPU utility functions. It has a null
// prototype, so it carries no inherited properties.
var gpgpu_util = Object.create(null);
gpgpu_util.createVertexShader = createVertexShader$1;
gpgpu_util.createVertexBuffer = createVertexBuffer;
gpgpu_util.createIndexBuffer = createIndexBuffer;
gpgpu_util.getInternalFormatForFloat32MatrixTexture = getInternalFormatForFloat32MatrixTexture;
gpgpu_util.createFloat32MatrixTexture = createFloat32MatrixTexture;
gpgpu_util.getInternalFormatForFloat16MatrixTexture = getInternalFormatForFloat16MatrixTexture;
gpgpu_util.createFloat16MatrixTexture = createFloat16MatrixTexture;
gpgpu_util.getInternalFormatForUnsignedBytesMatrixTexture = getInternalFormatForUnsignedBytesMatrixTexture;
gpgpu_util.createUnsignedBytesMatrixTexture = createUnsignedBytesMatrixTexture;
gpgpu_util.getInternalFormatForPackedMatrixTexture = getInternalFormatForPackedMatrixTexture;
gpgpu_util.createPackedMatrixTexture = createPackedMatrixTexture;
gpgpu_util.getInternalFormatForFloat16PackedMatrixTexture = getInternalFormatForFloat16PackedMatrixTexture;
gpgpu_util.createFloat16PackedMatrixTexture = createFloat16PackedMatrixTexture;
gpgpu_util.bindVertexProgramAttributeStreams = bindVertexProgramAttributeStreams;
gpgpu_util.uploadDenseMatrixToTexture = uploadDenseMatrixToTexture;
gpgpu_util.uploadPixelDataToTexture = uploadPixelDataToTexture;
gpgpu_util.createBufferFromOutputTexture = createBufferFromOutputTexture;
gpgpu_util.downloadFloat32MatrixFromBuffer = downloadFloat32MatrixFromBuffer;
gpgpu_util.downloadByteEncodedFloatMatrixFromOutputTexture = downloadByteEncodedFloatMatrixFromOutputTexture;
gpgpu_util.downloadPackedMatrixFromBuffer = downloadPackedMatrixFromBuffer;
gpgpu_util.downloadMatrixFromPackedOutputTexture = downloadMatrixFromPackedOutputTexture;
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2017 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Wraps a WebGL context together with the shared GPU objects (quad vertex
 * buffer, index buffer, framebuffer) used to run fragment-shader programs
 * that render into matrix textures, plus helpers for texture creation,
 * upload/download, timer queries and fence polling.
 */
var GPGPUContext = /** @class */ (function () {
    /**
     * @param gl Optional WebGL context to adopt. When null/undefined, a
     *     context is obtained from `getWebGLContext` for the configured
     *     `WEBGL_VERSION`.
     * @throws Error when a required float-texture capability is missing.
     */
    function GPGPUContext(gl) {
        this.outputTexture = null;
        this.program = null;
        this.disposed = false;
        this.vertexAttrsAreBound = false;
        this.itemsToPoll = [];
        var glVersion = tf.env().getNumber('WEBGL_VERSION');
        if (gl != null) {
            this.gl = gl;
            setWebGLContext(glVersion, gl);
        }
        else {
            this.gl = getWebGLContext(glVersion);
        }
        // WebGL 2.0 enables texture floats without an extension.
        var COLOR_BUFFER_FLOAT = 'WEBGL_color_buffer_float';
        var COLOR_BUFFER_HALF_FLOAT = 'EXT_color_buffer_half_float';
        if (tf.env().getNumber('WEBGL_VERSION') === 1) {
            var TEXTURE_FLOAT = 'OES_texture_float';
            var TEXTURE_HALF_FLOAT = 'OES_texture_half_float';
            this.textureFloatExtension =
                getExtensionOrThrow(this.gl, TEXTURE_FLOAT);
            if (hasExtension(this.gl, TEXTURE_HALF_FLOAT)) {
                this.textureHalfFloatExtension =
                    getExtensionOrThrow(this.gl, TEXTURE_HALF_FLOAT);
            }
            else if (tf.env().get('WEBGL_FORCE_F16_TEXTURES')) {
                throw new Error('GL context does not support half float textures, yet the ' +
                    'environment flag WEBGL_FORCE_F16_TEXTURES is set to true.');
            }
            this.colorBufferFloatExtension = this.gl.getExtension(COLOR_BUFFER_FLOAT);
            if (hasExtension(this.gl, COLOR_BUFFER_HALF_FLOAT)) {
                this.colorBufferHalfFloatExtension =
                    getExtensionOrThrow(this.gl, COLOR_BUFFER_HALF_FLOAT);
            }
            else if (tf.env().get('WEBGL_FORCE_F16_TEXTURES')) {
                throw new Error('GL context does not support color renderable half floats, yet ' +
                    'the environment flag WEBGL_FORCE_F16_TEXTURES is set to true.');
            }
        }
        else {
            COLOR_BUFFER_FLOAT = 'EXT_color_buffer_float';
            if (hasExtension(this.gl, COLOR_BUFFER_FLOAT)) {
                this.colorBufferFloatExtension =
                    this.gl.getExtension(COLOR_BUFFER_FLOAT);
            }
            else if (hasExtension(this.gl, COLOR_BUFFER_HALF_FLOAT)) {
                this.colorBufferHalfFloatExtension =
                    this.gl.getExtension(COLOR_BUFFER_HALF_FLOAT);
            }
            else {
                throw new Error('GL context does not support color renderable floats');
            }
        }
        this.vertexBuffer = createVertexBuffer(this.gl);
        this.indexBuffer = createIndexBuffer(this.gl);
        this.framebuffer = createFramebuffer(this.gl);
        this.textureConfig =
            getTextureConfig(this.gl, this.textureHalfFloatExtension);
    }
    /** Whether the `DEBUG` environment flag is set. */
    Object.defineProperty(GPGPUContext.prototype, "debug", {
        get: function () {
            return tf.env().getBool('DEBUG');
        },
        enumerable: true,
        configurable: true
    });
    /**
     * Releases the GPU objects owned by this context: the framebuffer, the
     * vertex buffer and the index buffer. Calling it again is a no-op. Warns
     * (without throwing) when a program or output texture is still set.
     */
    GPGPUContext.prototype.dispose = function () {
        var _this = this;
        if (this.disposed) {
            return;
        }
        if (this.program != null) {
            console.warn('Disposing a GPGPUContext that still has a bound WebGLProgram.' +
                ' This is probably a resource leak, delete the program with ' +
                'GPGPUContext.deleteProgram before disposing.');
        }
        if (this.outputTexture != null) {
            console.warn('Disposing a GPGPUContext that still has a bound output matrix ' +
                'texture. This is probably a resource leak, delete the output ' +
                'matrix texture with GPGPUContext.deleteMatrixTexture before ' +
                'disposing.');
        }
        var gl = this.gl;
        callAndCheck(gl, function () { return gl.finish(); });
        callAndCheck(gl, function () { return gl.bindFramebuffer(gl.FRAMEBUFFER, null); });
        callAndCheck(gl, function () { return gl.deleteFramebuffer(_this.framebuffer); });
        callAndCheck(gl, function () { return gl.bindBuffer(gl.ARRAY_BUFFER, null); });
        // The vertex buffer is created in the constructor alongside the index
        // buffer and framebuffer, so it must be released here as well.
        callAndCheck(gl, function () { return gl.deleteBuffer(_this.vertexBuffer); });
        callAndCheck(gl, function () { return gl.bindBuffer(gl.ELEMENT_ARRAY_BUFFER, null); });
        callAndCheck(gl, function () { return gl.deleteBuffer(_this.indexBuffer); });
        this.disposed = true;
    };
    /** Creates an unpacked float32 matrix texture. Throws if disposed. */
    GPGPUContext.prototype.createFloat32MatrixTexture = function (rows, columns) {
        this.throwIfDisposed();
        return createFloat32MatrixTexture(this.gl, rows, columns, this.textureConfig);
    };
    /** Creates an unpacked float16 matrix texture. Throws if disposed. */
    GPGPUContext.prototype.createFloat16MatrixTexture = function (rows, columns) {
        this.throwIfDisposed();
        return createFloat16MatrixTexture(this.gl, rows, columns, this.textureConfig);
    };
    /** Creates an unsigned-byte matrix texture. Throws if disposed. */
    GPGPUContext.prototype.createUnsignedBytesMatrixTexture = function (rows, columns) {
        this.throwIfDisposed();
        return createUnsignedBytesMatrixTexture(this.gl, rows, columns, this.textureConfig);
    };
    /** Uploads pixel data into `texture`. Throws if disposed. */
    GPGPUContext.prototype.uploadPixelDataToTexture = function (texture, pixels) {
        this.throwIfDisposed();
        uploadPixelDataToTexture(this.gl, texture, pixels);
    };
    /** Uploads dense matrix `data` into `texture`. Throws if disposed. */
    GPGPUContext.prototype.uploadDenseMatrixToTexture = function (texture, width, height, data) {
        this.throwIfDisposed();
        uploadDenseMatrixToTexture(this.gl, texture, width, height, data, this.textureConfig);
    };
    /** Creates a packed float16 matrix texture. Throws if disposed. */
    GPGPUContext.prototype.createFloat16PackedMatrixTexture = function (rows, columns) {
        this.throwIfDisposed();
        return createFloat16PackedMatrixTexture(this.gl, rows, columns, this.textureConfig);
    };
    /** Creates a packed float32 matrix texture. Throws if disposed. */
    GPGPUContext.prototype.createPackedMatrixTexture = function (rows, columns) {
        this.throwIfDisposed();
        return createPackedMatrixTexture(this.gl, rows, columns, this.textureConfig);
    };
    /**
     * Deletes `texture`. If it is the current output texture, it is first
     * detached from the framebuffer and the output slot is cleared.
     */
    GPGPUContext.prototype.deleteMatrixTexture = function (texture) {
        var _this = this;
        this.throwIfDisposed();
        if (this.outputTexture === texture) {
            unbindColorTextureFromFramebuffer(this.gl, this.framebuffer);
            this.outputTexture = null;
        }
        callAndCheck(this.gl, function () { return _this.gl.deleteTexture(texture); });
    };
    /** Downloads a byte-encoded float texture as a `Float32Array`. */
    GPGPUContext.prototype.downloadByteEncodedFloatMatrixFromOutputTexture = function (texture, rows, columns) {
        var _this = this;
        return this.downloadMatrixDriver(texture, function () { return downloadByteEncodedFloatMatrixFromOutputTexture(_this.gl, rows, columns, _this.textureConfig); });
    };
    /** Reads a packed matrix out of a pixel-pack `buffer`. */
    GPGPUContext.prototype.downloadPackedMatrixFromBuffer = function (buffer, batch, rows, columns, physicalRows, physicalCols) {
        return downloadPackedMatrixFromBuffer(this.gl, buffer, batch, rows, columns, physicalRows, physicalCols, this.textureConfig);
    };
    /** Reads `size` floats out of a pixel-pack `buffer`. */
    GPGPUContext.prototype.downloadFloat32MatrixFromBuffer = function (buffer, size) {
        return downloadFloat32MatrixFromBuffer(this.gl, buffer, size);
    };
    /**
     * Attaches `texture` to the framebuffer, enqueues a read of it into a
     * new pixel-pack buffer, then restores the previous attachment.
     */
    GPGPUContext.prototype.createBufferFromTexture = function (texture, rows, columns) {
        this.bindTextureToFrameBuffer(texture);
        var result = createBufferFromOutputTexture(this.gl, rows, columns, this.textureConfig);
        this.unbindTextureToFrameBuffer();
        return result;
    };
    /** Creates a fence and returns a promise resolved once it has passed. */
    GPGPUContext.prototype.createAndWaitForFence = function () {
        var fenceContext = this.createFence(this.gl);
        return this.pollFence(fenceContext);
    };
    /**
     * Creates a fence using the best available mechanism: the sync-object
     * API, a disjoint timer query, or (as a last resort) a fence that is
     * reported as passed immediately.
     *
     * @returns `{query, isFencePassed}` where `isFencePassed()` reports
     *     whether the fence has been reached.
     */
    GPGPUContext.prototype.createFence = function (gl) {
        var _this = this;
        var query;
        var isFencePassed;
        if (tf.env().getBool('WEBGL_FENCE_API_ENABLED')) {
            var gl2_1 = gl;
            var sync_1 = gl2_1.fenceSync(gl2_1.SYNC_GPU_COMMANDS_COMPLETE, 0);
            gl.flush();
            isFencePassed = function () {
                var status = gl2_1.clientWaitSync(sync_1, 0, 0);
                return status === gl2_1.ALREADY_SIGNALED ||
                    status === gl2_1.CONDITION_SATISFIED;
            };
            query = sync_1;
        }
        else if (tf.env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') > 0) {
            query = this.beginQuery();
            this.endQuery();
            isFencePassed = function () { return _this.isQueryAvailable(query, tf.env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION')); };
        }
        else {
            // If we have no way to fence, return true immediately. This will fire in
            // WebGL 1.0 when there is no disjoint query timer. In this case, because
            // the fence passes immediately, we'll immediately ask for a download of
            // the texture, which will cause the UI thread to hang.
            isFencePassed = function () { return true; };
        }
        return { query: query, isFencePassed: isFencePassed };
    };
    /** Downloads a packed texture as a `Float32Array` of RGBA texels. */
    GPGPUContext.prototype.downloadMatrixFromPackedTexture = function (texture, physicalRows, physicalCols) {
        var _this = this;
        return this.downloadMatrixDriver(texture, function () { return downloadMatrixFromPackedOutputTexture(_this.gl, physicalRows, physicalCols); });
    };
    /**
     * Compiles `fragmentShaderSource`, links it with the shared vertex
     * shader (created lazily on first use) and returns the program. The
     * first program created also becomes the current program so the vertex
     * attribute streams can be bound.
     */
    GPGPUContext.prototype.createProgram = function (fragmentShaderSource) {
        var _this = this;
        this.throwIfDisposed();
        var gl = this.gl;
        var fragmentShader = createFragmentShader(gl, fragmentShaderSource);
        if (this.vertexShader == null) {
            this.vertexShader = createVertexShader$1(gl);
        }
        var program = createProgram(gl);
        callAndCheck(gl, function () { return gl.attachShader(program, _this.vertexShader); });
        callAndCheck(gl, function () { return gl.attachShader(program, fragmentShader); });
        linkProgram(gl, program);
        if (this.debug) {
            validateProgram(gl, program);
        }
        if (!this.vertexAttrsAreBound) {
            this.setProgram(program);
            this.vertexAttrsAreBound = bindVertexProgramAttributeStreams(gl, this.program, this.vertexBuffer);
        }
        return program;
    };
    /** Deletes `program`, clearing the current program if it matches. */
    GPGPUContext.prototype.deleteProgram = function (program) {
        var _this = this;
        this.throwIfDisposed();
        if (program === this.program) {
            this.program = null;
        }
        if (program != null) {
            callAndCheck(this.gl, function () { return _this.gl.deleteProgram(program); });
        }
    };
    /** Makes `program` current (validating it first in debug mode). */
    GPGPUContext.prototype.setProgram = function (program) {
        var _this = this;
        this.throwIfDisposed();
        this.program = program;
        if ((this.program != null) && this.debug) {
            validateProgram(this.gl, this.program);
        }
        callAndCheck(this.gl, function () { return _this.gl.useProgram(program); });
    };
    /**
     * Looks up a uniform location.
     *
     * @param shouldThrow When true (the default) the throwing lookup helper
     *     is used; otherwise the non-throwing helper is used.
     */
    GPGPUContext.prototype.getUniformLocation = function (program, uniformName, shouldThrow) {
        if (shouldThrow === void 0) { shouldThrow = true; }
        this.throwIfDisposed();
        if (shouldThrow) {
            return getProgramUniformLocationOrThrow(this.gl, program, uniformName);
        }
        else {
            return getProgramUniformLocation(this.gl, program, uniformName);
        }
    };
    /** Looks up the location of `attribute` in `program`. */
    GPGPUContext.prototype.getAttributeLocation = function (program, attribute) {
        var _this = this;
        this.throwIfDisposed();
        return callAndCheck(this.gl, function () { return _this.gl.getAttribLocation(program, attribute); });
    };
    /** Looks up a uniform location directly on the GL context. */
    GPGPUContext.prototype.getUniformLocationNoThrow = function (program, uniformName) {
        this.throwIfDisposed();
        return this.gl.getUniformLocation(program, uniformName);
    };
    /** Binds an input texture to a sampler uniform of the current program. */
    GPGPUContext.prototype.setInputMatrixTexture = function (inputMatrixTexture, uniformLocation, textureUnit) {
        this.throwIfDisposed();
        this.throwIfNoProgram();
        bindTextureToProgramUniformSampler(this.gl, inputMatrixTexture, uniformLocation, textureUnit);
    };
    /** Sets the output texture; the viewport is `columns` x `rows`. */
    GPGPUContext.prototype.setOutputMatrixTexture = function (outputMatrixTexture, rows, columns) {
        this.setOutputMatrixTextureDriver(outputMatrixTexture, columns, rows);
    };
    /** Sets a packed output texture, sizing the viewport to its packed shape. */
    GPGPUContext.prototype.setOutputPackedMatrixTexture = function (outputPackedMatrixTexture, rows, columns) {
        this.throwIfDisposed();
        var _a = getPackedMatrixTextureShapeWidthHeight(rows, columns), width = _a[0], height = _a[1];
        this.setOutputMatrixTextureDriver(outputPackedMatrixTexture, width, height);
    };
    /** Restricts writes to a sub-rectangle of the output via the scissor box. */
    GPGPUContext.prototype.setOutputMatrixWriteRegion = function (startRow, numRows, startColumn, numColumns) {
        this.setOutputMatrixWriteRegionDriver(startColumn, startRow, numColumns, numRows);
    };
    /** Not implemented; always throws. */
    GPGPUContext.prototype.setOutputPackedMatrixWriteRegion = function (startRow, numRows, startColumn, numColumns) {
        throw new Error('setOutputPackedMatrixWriteRegion not implemented.');
    };
    /** Validates the current program (if any) and the framebuffer. */
    GPGPUContext.prototype.debugValidate = function () {
        if (this.program != null) {
            validateProgram(this.gl, this.program);
        }
        validateFramebuffer(this.gl);
    };
    /** Draws the quad (6 indices) with the current program. */
    GPGPUContext.prototype.executeProgram = function () {
        this.throwIfDisposed();
        this.throwIfNoProgram();
        var gl = this.gl;
        if (this.debug) {
            this.debugValidate();
        }
        callAndCheck(gl, function () { return gl.drawElements(gl.TRIANGLES, 6, gl.UNSIGNED_SHORT, 0); });
    };
    /** Blocks via `gl.finish()`. */
    GPGPUContext.prototype.blockUntilAllProgramsCompleted = function () {
        var _this = this;
        this.throwIfDisposed();
        callAndCheck(this.gl, function () { return _this.gl.finish(); });
    };
    /**
     * Returns the disjoint timer query extension, fetching and caching it
     * on first use. The extension name depends on
     * `WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION`.
     */
    GPGPUContext.prototype.getQueryTimerExtension = function () {
        if (this.disjointQueryTimerExtension == null) {
            this.disjointQueryTimerExtension =
                getExtensionOrThrow(this.gl, tf.env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') === 2 ?
                    'EXT_disjoint_timer_query_webgl2' :
                    'EXT_disjoint_timer_query');
        }
        return this.disjointQueryTimerExtension;
    };
    /** Alias of `getQueryTimerExtension`. */
    GPGPUContext.prototype.getQueryTimerExtensionWebGL2 = function () {
        return this.getQueryTimerExtension();
    };
    /** Alias of `getQueryTimerExtension`. */
    GPGPUContext.prototype.getQueryTimerExtensionWebGL1 = function () {
        return this.getQueryTimerExtension();
    };
    /** Starts a TIME_ELAPSED timer query and returns the query object. */
    GPGPUContext.prototype.beginQuery = function () {
        if (tf.env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') === 2) {
            var gl2 = this.gl;
            var ext_1 = this.getQueryTimerExtensionWebGL2();
            var query_1 = gl2.createQuery();
            gl2.beginQuery(ext_1.TIME_ELAPSED_EXT, query_1);
            return query_1;
        }
        var ext = this.getQueryTimerExtensionWebGL1();
        var query = ext.createQueryEXT();
        ext.beginQueryEXT(ext.TIME_ELAPSED_EXT, query);
        return query;
    };
    /** Ends the active TIME_ELAPSED timer query. */
    GPGPUContext.prototype.endQuery = function () {
        if (tf.env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') === 2) {
            var gl2 = this.gl;
            var ext_2 = this.getQueryTimerExtensionWebGL2();
            gl2.endQuery(ext_2.TIME_ELAPSED_EXT);
            return;
        }
        var ext = this.getQueryTimerExtensionWebGL1();
        ext.endQueryEXT(ext.TIME_ELAPSED_EXT);
    };
    /**
     * Polls until `query` is available (or this context is disposed) and
     * then resolves with the result of `getQueryTime`.
     */
    GPGPUContext.prototype.waitForQueryAndGetTime = function (query) {
        return __awaiter(this, void 0, void 0, function () {
            var _this = this;
            return __generator(this, function (_a) {
                switch (_a.label) {
                    case 0: return [4 /*yield*/, tf.util.repeatedTry(function () { return _this.disposed || // while testing contexts are created / disposed
                            // in rapid succession, so without this check we
                            // may poll for the query timer indefinitely
                            _this.isQueryAvailable(query, tf.env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION')); })];
                    case 1:
                        _a.sent();
                        return [2 /*return*/, this.getQueryTime(query, tf.env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION'))];
                }
            });
        });
    };
    /**
     * Reads the elapsed time of a finished query.
     *
     * @returns Milliseconds, or null when `queryTimerVersion` is 0.
     */
    GPGPUContext.prototype.getQueryTime = function (query, queryTimerVersion) {
        if (queryTimerVersion === 0) {
            return null;
        }
        if (queryTimerVersion === 2) {
            var gl2 = this.gl;
            var timeElapsedNanos = gl2.getQueryParameter(query, gl2.QUERY_RESULT);
            // Return milliseconds.
            return timeElapsedNanos / 1000000;
        }
        else {
            var ext = this.getQueryTimerExtensionWebGL1();
            var timeElapsedNanos = ext.getQueryObjectEXT(query, ext.QUERY_RESULT_EXT);
            // Return milliseconds.
            return timeElapsedNanos / 1000000;
        }
    };
    /**
     * Reports whether the result of `query` can be read. The GPU_DISJOINT
     * flag is read once and cached on `this.disjoint`.
     *
     * @returns true when `queryTimerVersion` is 0; otherwise whether the
     *     result is available and the cached disjoint flag is falsy.
     */
    GPGPUContext.prototype.isQueryAvailable = function (query, queryTimerVersion) {
        if (queryTimerVersion === 0) {
            return true;
        }
        if (queryTimerVersion === 2) {
            var gl2 = this.gl;
            var ext = this.getQueryTimerExtensionWebGL2();
            var available = gl2.getQueryParameter(query, gl2.QUERY_RESULT_AVAILABLE);
            if (this.disjoint == null) {
                this.disjoint = this.gl.getParameter(ext.GPU_DISJOINT_EXT);
            }
            return available && !this.disjoint;
        }
        else {
            var ext = this.getQueryTimerExtensionWebGL1();
            var available = ext.getQueryObjectEXT(query, ext.QUERY_RESULT_AVAILABLE_EXT);
            if (this.disjoint == null) {
                this.disjoint = this.gl.getParameter(ext.GPU_DISJOINT_EXT);
            }
            return available && !this.disjoint;
        }
    };
    /** Returns a promise resolved once `fenceContext.isFencePassed()` is true. */
    GPGPUContext.prototype.pollFence = function (fenceContext) {
        var _this = this;
        return new Promise(function (resolve) {
            _this.addItemToPoll(function () { return fenceContext.isFencePassed(); }, function () { return resolve(); });
        });
    };
    /**
     * Resolves every queued item up to and including the last consecutive
     * finished one, and removes those items from the queue.
     */
    GPGPUContext.prototype.pollItems = function () {
        // Find the last query that has finished.
        var index = linearSearchLastTrue(this.itemsToPoll.map(function (x) { return x.isDoneFn; }));
        for (var i = 0; i <= index; ++i) {
            var resolveFn = this.itemsToPoll[i].resolveFn;
            resolveFn();
        }
        this.itemsToPoll = this.itemsToPoll.slice(index + 1);
    };
    /**
     * Queues an item for polling and starts the polling loop if this is the
     * only queued item.
     */
    GPGPUContext.prototype.addItemToPoll = function (isDoneFn, resolveFn) {
        var _this = this;
        this.itemsToPoll.push({ isDoneFn: isDoneFn, resolveFn: resolveFn });
        if (this.itemsToPoll.length > 1) {
            // We already have a running loop that polls.
            return;
        }
        // Start a new loop that polls.
        tf.util.repeatedTry(function () {
            _this.pollItems();
            // End the loop if no more items to poll.
            return _this.itemsToPoll.length === 0;
        });
    };
    /** Attaches `texture` to the framebuffer as its color target. */
    GPGPUContext.prototype.bindTextureToFrameBuffer = function (texture) {
        this.throwIfDisposed();
        bindColorTextureToFramebuffer(this.gl, texture, this.framebuffer);
        if (this.debug) {
            validateFramebuffer(this.gl);
        }
    };
    /**
     * Restores the framebuffer attachment: re-attaches the current output
     * texture if there is one, otherwise detaches the color texture.
     */
    GPGPUContext.prototype.unbindTextureToFrameBuffer = function () {
        if (this.outputTexture != null) {
            bindColorTextureToFramebuffer(this.gl, this.outputTexture, this.framebuffer);
            if (this.debug) {
                validateFramebuffer(this.gl);
            }
        }
        else {
            unbindColorTextureFromFramebuffer(this.gl, this.framebuffer);
        }
    };
    /**
     * Attaches `texture` to the framebuffer, runs `downloadAndDecode`, then
     * restores the previous attachment and returns the decoded result.
     */
    GPGPUContext.prototype.downloadMatrixDriver = function (texture, downloadAndDecode) {
        this.bindTextureToFrameBuffer(texture);
        var result = downloadAndDecode();
        this.unbindTextureToFrameBuffer();
        return result;
    };
    /**
     * Attaches the output texture to the framebuffer, records it as the
     * current output, and sets viewport and scissor to `width` x `height`.
     */
    GPGPUContext.prototype.setOutputMatrixTextureDriver = function (outputMatrixTextureMaybePacked, width, height) {
        this.throwIfDisposed();
        var gl = this.gl;
        bindColorTextureToFramebuffer(gl, outputMatrixTextureMaybePacked, this.framebuffer);
        if (this.debug) {
            validateFramebuffer(gl);
        }
        this.outputTexture = outputMatrixTextureMaybePacked;
        callAndCheck(gl, function () { return gl.viewport(0, 0, width, height); });
        callAndCheck(gl, function () { return gl.scissor(0, 0, width, height); });
    };
    /** Sets the scissor box to the given rectangle. */
    GPGPUContext.prototype.setOutputMatrixWriteRegionDriver = function (x, y, width, height) {
        var _this = this;
        this.throwIfDisposed();
        callAndCheck(this.gl, function () { return _this.gl.scissor(x, y, width, height); });
    };
    /** Throws if this context has been disposed. */
    GPGPUContext.prototype.throwIfDisposed = function () {
        if (this.disposed) {
            throw new Error('Attempted to use disposed GPGPUContext.');
        }
    };
    /** Throws if no program is currently set. */
    GPGPUContext.prototype.throwIfNoProgram = function () {
        if (this.program == null) {
            throw new Error('No GPU program is currently set.');
        }
    };
    return GPGPUContext;
}());
|
|
/**
 * Returns the index of the last callback in the leading run of callbacks
 * that report true, or -1 when the first callback reports false (or the
 * array is empty). Callbacks are invoked in order and scanning stops at the
 * first falsy result.
 * Note: We can't do binary search because Chrome expects us to explicitly
 * test all fences before download:
 * https://github.com/tensorflow/tfjs/issues/1145
 */
function linearSearchLastTrue(arr) {
    var lastDone = -1;
    for (var idx = 0; idx < arr.length; idx++) {
        if (!arr[idx]()) {
            return lastDone;
        }
        lastDone = idx;
    }
    return lastDone;
}
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2017 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
var getBroadcastDims = tf.backend_util.getBroadcastDims;
|
|
/**
 * Assembles the full fragment shader source for a program by joining, one
 * per line: the shader prefix, the float-texture sample and set-output
 * snippets, the input uniform declarations, the output sampling snippet,
 * the input sampling snippets and finally `userCode`.
 *
 * @param inputsInfo Array of inputs; each has a `name` and a `shapeInfo`
 *     with `logicalShape` and `isUniform`.
 * @param outputShape Output descriptor with `logicalShape`, `texShape` and
 *     `isPacked`.
 * @param userCode Program-specific GLSL appended last.
 * @param usesPackedTextures When truthy, `SHADER_PACKED_PREFIX` is appended
 *     to the shader prefix.
 * @returns The shader source string.
 */
function makeShader(inputsInfo, outputShape, userCode, usesPackedTextures) {
    var uniformDeclarations = [];
    inputsInfo.forEach(function (input) {
        var elementCount = tf.util.sizeFromShape(input.shapeInfo.logicalShape);
        if (!input.shapeInfo.isUniform) {
            // Texture-backed input: a sampler plus its offset uniform.
            uniformDeclarations.push("uniform sampler2D " + input.name + ";");
            uniformDeclarations.push("uniform int offset" + input.name + ";");
            return;
        }
        // Input uploaded as a uniform: a float, or a float array when it
        // holds more than one element.
        var arraySuffix = elementCount > 1 ? "[" + elementCount + "]" : '';
        uniformDeclarations.push("uniform float " + input.name + arraySuffix + ";");
    });
    var inputSampling = inputsInfo
        .map(function (input) { return getInputSamplingSnippet(input, outputShape, usesPackedTextures); })
        .join('\n');
    var outputTexShape = outputShape.texShape;
    var glsl = getGlslDifferences();
    var sampleSnippet = getFloatTextureSampleSnippet(glsl);
    var prefix = getShaderPrefix(glsl);
    var outputSampling;
    var setOutputSnippet;
    if (outputShape.isPacked) {
        outputSampling = getPackedOutputSamplingSnippet(outputShape.logicalShape, outputTexShape);
        setOutputSnippet = getFloatTextureSetRGBASnippet(glsl);
    }
    else {
        outputSampling = getOutputSamplingSnippet(outputShape.logicalShape, outputTexShape);
        setOutputSnippet = getFloatTextureSetRSnippet(glsl);
    }
    if (usesPackedTextures) {
        prefix += SHADER_PACKED_PREFIX;
    }
    var sections = [
        prefix,
        sampleSnippet,
        setOutputSnippet,
        uniformDeclarations.join('\n'),
        outputSampling,
        inputSampling,
        userCode
    ];
    return sections.join('\n');
}
|
|
/**
 * Picks the unpacked sampler generator that matches the rank of the input's
 * logical shape (ranks 0 through 6).
 *
 * @param inInfo Input descriptor exposing `shapeInfo.logicalShape`.
 * @returns GLSL source produced by the rank-specific sampler generator.
 * @throws {Error} When the rank is not one of 0..6.
 */
function getSamplerFromInInfo(inInfo) {
    var rank = inInfo.shapeInfo.logicalShape.length;
    if (rank === 0) {
        return getSamplerScalar(inInfo);
    }
    if (rank === 1) {
        return getSampler1D(inInfo);
    }
    if (rank === 2) {
        return getSampler2D(inInfo);
    }
    if (rank === 3) {
        return getSampler3D(inInfo);
    }
    if (rank === 4) {
        return getSampler4D(inInfo);
    }
    if (rank === 5) {
        return getSampler5D(inInfo);
    }
    if (rank === 6) {
        return getSampler6D(inInfo);
    }
    throw new Error(rank + "-D input sampling" +
        " is not yet supported");
}
|
|
/**
 * Picks the packed sampler generator for an input: dedicated generators for
 * ranks 0 through 3, and the generic N-D generator for every other rank.
 *
 * @param inInfo Input descriptor exposing `shapeInfo.logicalShape`.
 * @returns GLSL source produced by the selected packed sampler generator.
 */
function getPackedSamplerFromInInfo(inInfo) {
    var rank = inInfo.shapeInfo.logicalShape.length;
    if (rank === 0) {
        return getPackedSamplerScalar(inInfo);
    }
    if (rank === 1) {
        return getPackedSampler1D(inInfo);
    }
    if (rank === 2) {
        return getPackedSampler2D(inInfo);
    }
    if (rank === 3) {
        return getPackedSampler3D(inInfo);
    }
    return getPackedSamplerND(inInfo);
}
|
|
/**
 * Generates the GLSL sampling helpers for a single input.
 *
 * The rank-specific sampler is always emitted. The "sample at the output
 * coordinates" helper is appended only when the input rank does not exceed
 * the output rank.
 *
 * @param inInfo Input descriptor exposing `shapeInfo.logicalShape`.
 * @param outShapeInfo Output descriptor exposing `logicalShape`.
 * @param usesPackedTextures Selects the packed variants; defaults to false.
 * @returns Concatenated GLSL source for this input.
 */
function getInputSamplingSnippet(inInfo, outShapeInfo, usesPackedTextures) {
    var packed = usesPackedTextures === void 0 ? false : usesPackedTextures;
    var snippet = '' + (packed ? getPackedSamplerFromInInfo(inInfo) : getSamplerFromInInfo(inInfo));
    var inRank = inInfo.shapeInfo.logicalShape.length;
    var outRank = outShapeInfo.logicalShape.length;
    if (!(inRank <= outRank)) {
        return snippet;
    }
    if (packed) {
        return snippet + getPackedSamplerAtOutputCoords(inInfo, outShapeInfo);
    }
    return snippet + getSamplerAtOutputCoords(inInfo, outShapeInfo);
}
|
|
/**
 * Selects the `getOutputCoords()` generator for a packed output: dedicated
 * generators for ranks 0 through 3, the generic N-D generator otherwise.
 *
 * @param outShape Logical output shape.
 * @param outTexShape Texture shape backing the output.
 * @returns GLSL source defining `getOutputCoords()`.
 */
function getPackedOutputSamplingSnippet(outShape, outTexShape) {
    var rank = outShape.length;
    if (rank === 0) {
        return getOutputScalarCoords();
    }
    if (rank === 1) {
        return getOutputPacked1DCoords(outShape, outTexShape);
    }
    if (rank === 2) {
        return getOutputPacked2DCoords(outShape, outTexShape);
    }
    if (rank === 3) {
        return getOutputPacked3DCoords(outShape, outTexShape);
    }
    return getOutputPackedNDCoords(outShape, outTexShape);
}
|
|
/**
 * Selects the `getOutputCoords()` generator for an unpacked output of rank
 * 0 through 6.
 *
 * @param outShape Logical output shape.
 * @param outTexShape Texture shape backing the output.
 * @returns GLSL source defining `getOutputCoords()`.
 * @throws {Error} When the rank is not one of 0..6.
 */
function getOutputSamplingSnippet(outShape, outTexShape) {
    var rank = outShape.length;
    if (rank === 0) {
        return getOutputScalarCoords();
    }
    if (rank === 1) {
        return getOutput1DCoords(outShape, outTexShape);
    }
    if (rank === 2) {
        return getOutput2DCoords(outShape, outTexShape);
    }
    if (rank === 3) {
        return getOutput3DCoords(outShape, outTexShape);
    }
    if (rank === 4) {
        return getOutput4DCoords(outShape, outTexShape);
    }
    if (rank === 5) {
        return getOutput5DCoords(outShape, outTexShape);
    }
    if (rank === 6) {
        return getOutput6DCoords(outShape, outTexShape);
    }
    throw new Error(rank + "-D output sampling is not yet supported");
}
|
|
/**
 * Emits the GLSL `sampleTexture` helper, which reads the red channel of a
 * texel using the dialect's texture lookup function (`glsl.texture2D`).
 *
 * @param glsl Dialect descriptor exposing `texture2D`.
 * @returns GLSL source for `sampleTexture`.
 */
function getFloatTextureSampleSnippet(glsl) {
    var lookup = glsl.texture2D;
    var header = "\n float sampleTexture(sampler2D textureSampler, vec2 uv) {\n";
    var body = " return " + lookup + "(textureSampler, uv).r;\n";
    return header + body + " }\n ";
}
|
|
/**
 * Emits the scalar `setOutput(float)` helper, which writes the value into
 * the first component of the dialect's output variable and zeros the rest.
 *
 * @param glsl Dialect descriptor exposing `output`.
 * @returns GLSL source for `setOutput(float)`.
 */
function getFloatTextureSetRSnippet(glsl) {
    var target = glsl.output;
    var header = "\n void setOutput(float val) {\n";
    var assignment = " " + target + " = vec4(val, 0, 0, 0);\n";
    return header + assignment + " }\n ";
}
|
|
/**
 * Emits the packed `setOutput(vec4)` helper, which assigns all four
 * components to the dialect's output variable.
 *
 * @param glsl Dialect descriptor exposing `output`.
 * @returns GLSL source for `setOutput(vec4)`.
 */
function getFloatTextureSetRGBASnippet(glsl) {
    var target = glsl.output;
    var header = "\n void setOutput(vec4 val) {\n";
    var assignment = " " + target + " = val;\n";
    return header + assignment + " }\n ";
}
|
|
/**
 * Builds the preamble shared by every generated fragment shader: version and
 * precision qualifiers, the `resultUV` varying, output definition, the
 * `ivec5`/`ivec6` structs, NaN/Inf/round definitions from the dialect,
 * integer `imod`/`idiv` helpers, a `random` helper and the flat-index to UV
 * helper snippets.
 *
 * @param glsl Dialect descriptor exposing `version`, `varyingFs`,
 *     `defineOutput`, `defineSpecialNaN`, `defineSpecialInf`, `defineRound`.
 * @returns GLSL source for the shader preamble.
 */
function getShaderPrefix(glsl) {
    var precisionAndIo = glsl.version + "\n precision highp float;\n precision highp int;\n precision highp sampler2D;\n " + glsl.varyingFs + " vec2 resultUV;\n " + glsl.defineOutput + "\n const vec2 halfCR = vec2(0.5, 0.5);\n\n";
    // GLSL has no built-in 5- or 6-component integer vectors.
    var wideVectorStructs = " struct ivec5\n {\n int x;\n int y;\n int z;\n int w;\n int u;\n };\n\n struct ivec6\n {\n int x;\n int y;\n int z;\n int w;\n int u;\n int v;\n };\n\n";
    var specialValues = " uniform float NAN;\n " + glsl.defineSpecialNaN + "\n " + glsl.defineSpecialInf + "\n " + glsl.defineRound + "\n\n";
    var integerHelpers = " int imod(int x, int y) {\n return x - y * (x / y);\n }\n\n int idiv(int a, int b, float sign) {\n int res = a / b;\n int mod = imod(a, b);\n if (sign < 0. && mod != 0) {\n res -= 1;\n }\n return res;\n }\n\n";
    var randomHelper = " //Based on the work of Dave Hoskins\n //https://www.shadertoy.com/view/4djSRW\n #define HASHSCALE1 443.8975\n float random(float seed){\n vec2 p = resultUV * seed;\n vec3 p3 = fract(vec3(p.xyx) * HASHSCALE1);\n p3 += dot(p3, p3.yzx + 19.19);\n return fract((p3.x + p3.y) * p3.z);\n }\n\n";
    var uvHelpers = " " + SAMPLE_1D_SNIPPET + "\n " + SAMPLE_2D_SNIPPET + "\n " + SAMPLE_3D_SNIPPET + "\n ";
    return precisionAndIo + wideVectorStructs + specialValues + integerHelpers + randomHelper + uvHelpers;
}
|
|
// GLSL helpers mapping a flat (or packed 1D) index to texture UV coordinates.
var SAMPLE_1D_SNIPPET = [
    "",
    "vec2 uvFromFlat(int texNumR, int texNumC, int index) {",
    " int texR = index / texNumC;",
    " int texC = index - texR * texNumC;",
    " return (vec2(texC, texR) + halfCR) / vec2(texNumC, texNumR);",
    "}",
    "vec2 packedUVfrom1D(int texNumR, int texNumC, int index) {",
    " int texelIndex = index / 2;",
    " int texR = texelIndex / texNumC;",
    " int texC = texelIndex - texR * texNumC;",
    " return (vec2(texC, texR) + halfCR) / vec2(texNumC, texNumR);",
    "}",
    ""
].join("\n");
// GLSL helper mapping a logical (row, col) of a packed 2D tensor to UV.
var SAMPLE_2D_SNIPPET = [
    "",
    "vec2 packedUVfrom2D(int texelsInLogicalRow, int texNumR,",
    " int texNumC, int row, int col) {",
    " int texelIndex = (row / 2) * texelsInLogicalRow + (col / 2);",
    " int texR = texelIndex / texNumC;",
    " int texC = texelIndex - texR * texNumC;",
    " return (vec2(texC, texR) + halfCR) / vec2(texNumC, texNumR);",
    "}",
    ""
].join("\n");
// GLSL helper mapping a logical (b, row, col) of a packed 3D tensor to UV.
var SAMPLE_3D_SNIPPET = [
    "",
    "vec2 packedUVfrom3D(int texNumR, int texNumC,",
    " int texelsInBatch, int texelsInLogicalRow, int b,",
    " int row, int col) {",
    " int index = b * texelsInBatch + (row / 2) * texelsInLogicalRow + (col / 2);",
    " int texR = index / texNumC;",
    " int texC = index - texR * texNumC;",
    " return (vec2(texC, texR) + halfCR) / vec2(texNumC, texNumR);",
    "}",
    ""
].join("\n");
// GLSL helpers selecting one channel of a packed texel from the parity of
// the inner coordinates.
var SHADER_PACKED_PREFIX = [
    "",
    " float getChannel(vec4 frag, vec2 innerDims) {",
    " vec2 modCoord = mod(innerDims, 2.);",
    " return modCoord.x == 0. ?",
    " (modCoord.y == 0. ? frag.r : frag.g) :",
    " (modCoord.y == 0. ? frag.b : frag.a);",
    " }",
    " float getChannel(vec4 frag, int dim) {",
    " float modCoord = mod(float(dim), 2.);",
    " return modCoord == 0. ? frag.r : frag.g;",
    " }",
    ""
].join("\n");
|
|
/**
 * Emits `getOutputCoords()` for a rank-0 output; the only coordinate is 0.
 *
 * @returns GLSL source defining `int getOutputCoords()`.
 */
function getOutputScalarCoords() {
    var source = "\n int getOutputCoords() {\n" +
        " return 0;\n" +
        " }\n ";
    return source;
}
|
|
/**
 * Emits `getOutputCoords()` for a packed rank-1 output.
 *
 * The packed texture has half the rows and columns of `texShape` (rounded
 * up). Single-row and single-column textures get a direct lookup along the
 * one varying axis; otherwise the 2D texel position is flattened. The result
 * is multiplied by 2 in every variant.
 *
 * @param shape Logical output shape (unused here).
 * @param texShape [rows, cols] of the unpacked texture.
 * @returns GLSL source defining `int getOutputCoords()`.
 */
function getOutputPacked1DCoords(shape, texShape) {
    var packedRows = Math.ceil(texShape[0] / 2);
    var packedCols = Math.ceil(texShape[1] / 2);
    if (packedRows === 1) {
        return "\n int getOutputCoords() {\n return 2 * int(resultUV.x * " + packedCols + ".0);\n }\n ";
    }
    if (packedCols === 1) {
        return "\n int getOutputCoords() {\n return 2 * int(resultUV.y * " + packedRows + ".0);\n }\n ";
    }
    return "\n int getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(" + packedRows + ", " + packedCols + "));\n return 2 * (resTexRC.x * " + packedCols + " + resTexRC.y);\n }\n ";
}
|
|
/**
 * Emits `getOutputCoords()` for an unpacked rank-1 output.
 *
 * Single-row and single-column textures get a direct lookup along the one
 * varying axis; otherwise the 2D texel position is flattened row-major.
 *
 * @param shape Logical output shape (unused here).
 * @param texShape [rows, cols] of the backing texture.
 * @returns GLSL source defining `int getOutputCoords()`.
 */
function getOutput1DCoords(shape, texShape) {
    var rows = texShape[0];
    var cols = texShape[1];
    if (rows === 1) {
        return "\n int getOutputCoords() {\n return int(resultUV.x * " + cols + ".0);\n }\n ";
    }
    if (cols === 1) {
        return "\n int getOutputCoords() {\n return int(resultUV.y * " + rows + ".0);\n }\n ";
    }
    return "\n int getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(" + rows + ", " + cols + "));\n return resTexRC.x * " + cols + " + resTexRC.y;\n }\n ";
}
|
|
/**
 * Emits `getOutputCoords()` for a packed rank-3 output.
 *
 * The generated code flattens the packed texel position into an index, then
 * splits it into a batch and an even (row, col) pair using the number of
 * texels per logical row and per batch.
 *
 * @param shape Logical output shape [batch, rows, cols].
 * @param texShape [rows, cols] of the unpacked texture.
 * @returns GLSL source defining `ivec3 getOutputCoords()`.
 */
function getOutputPacked3DCoords(shape, texShape) {
    var packedRows = Math.ceil(texShape[0] / 2);
    var packedCols = Math.ceil(texShape[1] / 2);
    var rowTexels = Math.ceil(shape[2] / 2);
    var batchTexels = rowTexels * Math.ceil(shape[1] / 2);
    return "\n ivec3 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(" + packedRows + ", " + packedCols + "));\n int index = resTexRC.x * " + packedCols + " + resTexRC.y;\n\n int b = index / " + batchTexels + ";\n index -= b * " + batchTexels + ";\n\n int r = 2 * (index / " + rowTexels + ");\n int c = imod(index, " + rowTexels + ") * 2;\n\n return ivec3(b, r, c);\n }\n ";
}
|
|
/**
 * Emits `getOutputCoords()` for an unpacked rank-3 output: the texel
 * position is flattened to an index, which the snippet produced by
 * `getLogicalCoordinatesFromFlatIndex` converts into `r`, `c`, `d`.
 *
 * @param shape Logical output shape.
 * @param texShape [rows, cols] of the backing texture.
 * @returns GLSL source defining `ivec3 getOutputCoords()`.
 */
function getOutput3DCoords(shape, texShape) {
    var rows = texShape[0];
    var cols = texShape[1];
    var unflatten = getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd'], shape);
    return "\n ivec3 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(" + rows + ", " + cols + "));\n int index = resTexRC.x * " + cols + " + resTexRC.y;\n " + unflatten + "\n return ivec3(r, c, d);\n }\n ";
}
|
|
/**
 * Emits `getOutputCoords()` for a packed output of arbitrary rank.
 *
 * The two innermost dimensions are packed (two values per texel along each),
 * giving the `r`/`c` coordinates; the dimension before them is the batch
 * `b`. Every further outer dimension gets its own `bN` variable, peeled off
 * the flat texel index from the outermost inwards.
 *
 * @param shape Logical output shape.
 * @param texShape [rows, cols] of the unpacked texture.
 * @returns GLSL source defining `ivec<rank> getOutputCoords()`.
 */
function getOutputPackedNDCoords(shape, texShape) {
    var rank = shape.length;
    var packedRows = Math.ceil(texShape[0] / 2);
    var packedCols = Math.ceil(texShape[1] / 2);
    var rowTexels = Math.ceil(shape[rank - 1] / 2);
    var batchTexels = rowTexels * Math.ceil(shape[rank - 2] / 2);
    // Running texel count of everything inside the dimension being peeled.
    var outerTexels = batchTexels;
    var outerBatchDecls = "";
    var coordList = 'b, r, c';
    var b = 2;
    while (b < rank - 1) {
        outerTexels *= shape[rank - b - 1];
        // Prepend so the outermost dimension is extracted first.
        outerBatchDecls = "\n int b" + b + " = index / " + outerTexels + ";\n index -= b" + b + " * " + outerTexels + ";\n " + outerBatchDecls;
        coordList = "b" + b + ", " + coordList;
        b++;
    }
    return "\n ivec" + rank + " getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(" + packedRows + ", " + packedCols + "));\n int index = resTexRC.x * " + packedCols + " + resTexRC.y;\n\n " + outerBatchDecls + "\n\n int b = index / " + batchTexels + ";\n index -= b * " + batchTexels + ";\n\n int r = 2 * (index / " + rowTexels + ");\n int c = imod(index, " + rowTexels + ") * 2;\n\n return ivec" + rank + "(" + coordList + ");\n }\n ";
}
|
|
/**
 * Emits `getOutputCoords()` for an unpacked rank-4 output: the texel
 * position is flattened to an index, which the snippet produced by
 * `getLogicalCoordinatesFromFlatIndex` converts into `r`, `c`, `d`, `d2`.
 *
 * @param shape Logical output shape.
 * @param texShape [rows, cols] of the backing texture.
 * @returns GLSL source defining `ivec4 getOutputCoords()`.
 */
function getOutput4DCoords(shape, texShape) {
    var rows = texShape[0];
    var cols = texShape[1];
    var unflatten = getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd', 'd2'], shape);
    return "\n ivec4 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(" + rows + ", " + cols + "));\n int index = resTexRC.x * " + cols + " + resTexRC.y;\n " + unflatten + "\n return ivec4(r, c, d, d2);\n }\n ";
}
|
|
/**
 * Emits `getOutputCoords()` for an unpacked rank-5 output: the texel
 * position is flattened to an index, which the snippet produced by
 * `getLogicalCoordinatesFromFlatIndex` converts into `r`, `c`, `d`, `d2`,
 * `d3`, returned as the custom `ivec5` struct.
 *
 * @param shape Logical output shape.
 * @param texShape [rows, cols] of the backing texture.
 * @returns GLSL source defining `ivec5 getOutputCoords()`.
 */
function getOutput5DCoords(shape, texShape) {
    var rows = texShape[0];
    var cols = texShape[1];
    var unflatten = getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd', 'd2', 'd3'], shape);
    return "\n ivec5 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx * vec2(" + rows + ",\n " + cols + "));\n\n int index = resTexRC.x * " + cols + " + resTexRC.y;\n\n " + unflatten + "\n\n ivec5 outShape = ivec5(r, c, d, d2, d3);\n return outShape;\n }\n ";
}
|
|
/**
 * Emits `getOutputCoords()` for an unpacked rank-6 output: the texel
 * position is flattened to an index, which the snippet produced by
 * `getLogicalCoordinatesFromFlatIndex` converts into `r`, `c`, `d`, `d2`,
 * `d3`, `d4`, returned as the custom `ivec6` struct.
 *
 * @param shape Logical output shape.
 * @param texShape [rows, cols] of the backing texture.
 * @returns GLSL source defining `ivec6 getOutputCoords()`.
 */
function getOutput6DCoords(shape, texShape) {
    var rows = texShape[0];
    var cols = texShape[1];
    var unflatten = getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd', 'd2', 'd3', 'd4'], shape);
    return "\n ivec6 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(" + rows + ", " + cols + "));\n int index = resTexRC.x * " + cols + " + resTexRC.y;\n\n " + unflatten + "\n\n ivec6 result = ivec6(r, c, d, d2, d3, d4);\n return result;\n }\n ";
}
|
|
/**
 * Emits `getOutputCoords()` for a packed rank-2 output.
 *
 * When the logical shape equals the texture shape the packed texel position
 * maps directly to coordinates (times 2). Otherwise the texel position is
 * flattened to a texel index and split using the number of texels needed to
 * hold one logical row.
 *
 * In the generated code `resTexRC` holds the row and column of the texel:
 * moving one texel to the right in the packed texture moves one column (not
 * two), and `index` is the flat texel index.
 *
 * @param shape Logical output shape [rows, cols].
 * @param texShape [rows, cols] of the unpacked texture.
 * @returns GLSL source defining `ivec2 getOutputCoords()`.
 */
function getOutputPacked2DCoords(shape, texShape) {
    var packedRows = Math.ceil(texShape[0] / 2);
    var packedCols = Math.ceil(texShape[1] / 2);
    var isDirectMapping = tf.util.arraysEqual(shape, texShape);
    if (isDirectMapping) {
        return "\n ivec2 getOutputCoords() {\n return 2 * ivec2(resultUV.yx * vec2(" + packedRows + ", " + packedCols + "));\n }\n ";
    }
    // Texels needed to accommodate one logical row.
    var rowTexels = Math.ceil(shape[1] / 2);
    return "\n ivec2 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(" + packedRows + ", " + packedCols + "));\n\n int index = resTexRC.x * " + packedCols + " + resTexRC.y;\n int r = 2 * (index / " + rowTexels + ");\n int c = imod(index, " + rowTexels + ") * 2;\n\n return ivec2(r, c);\n }\n ";
}
|
|
/**
 * Emits `getOutputCoords()` for an unpacked rank-2 output.
 *
 * When the logical shape equals the texture shape the texel position is the
 * coordinate. Otherwise the texel position is flattened to an index which
 * becomes the row (single-column shape), the column (single-row shape), or
 * is split into (row, col) by the logical column count.
 *
 * @param shape Logical output shape [rows, cols].
 * @param texShape [rows, cols] of the backing texture.
 * @returns GLSL source defining `ivec2 getOutputCoords()`.
 */
function getOutput2DCoords(shape, texShape) {
    var rows = texShape[0];
    var cols = texShape[1];
    if (tf.util.arraysEqual(shape, texShape)) {
        return "\n ivec2 getOutputCoords() {\n return ivec2(resultUV.yx * vec2(" + rows + ", " + cols + "));\n }\n ";
    }
    // Shared opening for every remaining variant: compute the flat index.
    var flatIndexHeader = "\n ivec2 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(" + rows + ", " + cols + "));\n int index = resTexRC.x * " + cols + " + resTexRC.y;\n";
    if (shape[1] === 1) {
        return flatIndexHeader + " return ivec2(index, 0);\n }\n ";
    }
    if (shape[0] === 1) {
        return flatIndexHeader + " return ivec2(0, index);\n }\n ";
    }
    return flatIndexHeader + " int r = index / " + shape[1] + ";\n int c = index - r * " + shape[1] + ";\n return ivec2(r, c);\n }\n ";
}
|
|
/**
 * Derives the name of the flat-offset uniform that accompanies a texture by
 * prefixing the texture name with "offset".
 *
 * @param texName Name of the texture uniform.
 * @returns The offset uniform's name.
 */
function getFlatOffsetUniformName(texName) {
    var prefix = "offset";
    return prefix + texName;
}
|
|
/**
 * Emits the packed getter for a rank-0 input: a zero-argument function named
 * `get<Name>` that returns the whole texel fetched at `halfCR`.
 *
 * @param inputInfo Input descriptor exposing `name`.
 * @returns GLSL source defining `vec4 get<Name>()`.
 */
function getPackedSamplerScalar(inputInfo) {
    var texName = inputInfo.name;
    var getterName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
    var lookup = getGlslDifferences().texture2D;
    return "\n vec4 " + getterName + "() {\n return " + lookup + "(" + texName + ", halfCR);\n }\n ";
}
|
|
/**
 * Emits the unpacked getter for a rank-0 input, a zero-argument function
 * named `get<Name>`.
 *
 * Uniform inputs return the uniform directly. A 1x1 texture is sampled at
 * `halfCR`. Any other texture is sampled at the UV derived from the input's
 * flat-offset uniform.
 *
 * @param inputInfo Input descriptor exposing `name` and `shapeInfo`
 *     (`isUniform`, `texShape`).
 * @returns GLSL source defining `float get<Name>()`.
 */
function getSamplerScalar(inputInfo) {
    var texName = inputInfo.name;
    var getterName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
    if (inputInfo.shapeInfo.isUniform) {
        return "float " + getterName + "() {return " + texName + ";}";
    }
    var texShape = inputInfo.shapeInfo.texShape;
    var rows = texShape[0];
    var cols = texShape[1];
    var isSingleTexel = rows === 1 && cols === 1;
    if (isSingleTexel) {
        return "\n float " + getterName + "() {\n return sampleTexture(" + texName + ", halfCR);\n }\n ";
    }
    var offsetUniform = getFlatOffsetUniformName(texName);
    return "\n float " + getterName + "() {\n vec2 uv = uvFromFlat(" + rows + ", " + cols + ", " + offsetUniform + ");\n return sampleTexture(" + texName + ", uv);\n }\n ";
}
|
|
/**
 * Emits the packed getter for a rank-1 input: `get<Name>(int index)` maps
 * the index to a UV with `packedUVfrom1D` over the packed texture size (half
 * of `texShape`, rounded up) and returns the whole texel.
 *
 * @param inputInfo Input descriptor exposing `name` and `shapeInfo.texShape`.
 * @returns GLSL source defining `vec4 get<Name>(int index)`.
 */
function getPackedSampler1D(inputInfo) {
    var texName = inputInfo.name;
    var getterName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
    var texShape = inputInfo.shapeInfo.texShape;
    var packedRows = Math.ceil(texShape[0] / 2);
    var packedCols = Math.ceil(texShape[1] / 2);
    var lookup = getGlslDifferences().texture2D;
    return "\n vec4 " + getterName + "(int index) {\n vec2 uv = packedUVfrom1D(\n " + packedRows + ", " + packedCols + ", index);\n return " + lookup + "(" + texName + ", uv);\n }\n ";
}
|
|
/**
 * Emits the unpacked getter for a rank-1 input, `get<Name>(int index)`.
 *
 * Variants, in order of precedence: uniform array lookup; 1x1 texture
 * sampled at `halfCR`; single-column texture (index addresses the row);
 * single-row texture (index addresses the column); general texture using
 * `uvFromFlat`. All texture variants except the 1x1 case add the input's
 * flat-offset uniform to the index.
 *
 * @param inputInfo Input descriptor exposing `name` and `shapeInfo`
 *     (`isUniform`, `texShape`).
 * @returns GLSL source defining `float get<Name>(int index)`.
 */
function getSampler1D(inputInfo) {
    var texName = inputInfo.name;
    var getterName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
    var signature = "float " + getterName + "(int index) {\n";
    if (inputInfo.shapeInfo.isUniform) {
        // Uniform arrays will be less than 65505 (no risk of float16 overflow).
        return "\n " + signature + " " + getUniformSampler(inputInfo) + "\n }\n ";
    }
    var texShape = inputInfo.shapeInfo.texShape;
    var rows = texShape[0];
    var cols = texShape[1];
    if (cols === 1 && rows === 1) {
        return "\n " + signature + " return sampleTexture(" + texName + ", halfCR);\n }\n ";
    }
    var offsetUniform = getFlatOffsetUniformName(texName);
    if (cols === 1) {
        return "\n " + signature + " vec2 uv = vec2(0.5, (float(index + " + offsetUniform + ") + 0.5) / " + rows + ".0);\n return sampleTexture(" + texName + ", uv);\n }\n ";
    }
    if (rows === 1) {
        return "\n " + signature + " vec2 uv = vec2((float(index + " + offsetUniform + ") + 0.5) / " + cols + ".0, 0.5);\n return sampleTexture(" + texName + ", uv);\n }\n ";
    }
    return "\n " + signature + " vec2 uv = uvFromFlat(" + rows + ", " + cols + ", index + " + offsetUniform + ");\n return sampleTexture(" + texName + ", uv);\n }\n ";
}
|
|
/**
 * Emits the packed getter for a rank-2 input, `get<Name>(int row, int col)`,
 * returning the whole texel.
 *
 * When the logical shape equals the texture shape the UV is computed
 * directly from (col, row). Otherwise `packedUVfrom2D` is used with the
 * number of texels per logical row and the packed texture size.
 *
 * @param inputInfo Input descriptor exposing `name` and `shapeInfo`
 *     (`logicalShape`, `texShape`).
 * @returns GLSL source defining `vec4 get<Name>(int row, int col)`.
 */
function getPackedSampler2D(inputInfo) {
    var logicalShape = inputInfo.shapeInfo.logicalShape;
    var texName = inputInfo.name;
    var getterName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
    var texShape = inputInfo.shapeInfo.texShape;
    var rows = texShape[0];
    var cols = texShape[1];
    var lookup = getGlslDifferences().texture2D;
    var signature = "vec4 " + getterName + "(int row, int col) {\n";
    if (texShape != null && tf.util.arraysEqual(logicalShape, texShape)) {
        return "\n " + signature + " vec2 uv = (vec2(col, row) + halfCR) / vec2(" + cols + ".0, " + rows + ".0);\n\n return " + lookup + "(" + texName + ", uv);\n }\n ";
    }
    var packedRows = Math.ceil(rows / 2);
    var packedCols = Math.ceil(cols / 2);
    var texelsPerRow = Math.ceil(logicalShape[1] / 2);
    return "\n " + signature + " vec2 uv = packedUVfrom2D(" + texelsPerRow + ", " + packedRows + ", " + packedCols + ", row, col);\n return " + lookup + "(" + texName + ", uv);\n }\n ";
}
|
|
/**
 * Emits the unpacked getter for a rank-2 input, `get<Name>(int row, int col)`.
 *
 * Variants, in order of precedence:
 *  - logical shape equals texture shape: direct UV from (col, row);
 *  - shape has size-1 dimensions: emit the sampler for the squeezed shape
 *    and a forwarding overload;
 *  - uniform input: flat index into the uniform array;
 *  - single-column / single-row texture: flat index addresses one axis;
 *  - general case: integer flat index converted with `uvFromFlat`.
 *
 * @param inputInfo Input descriptor exposing `name` and `shapeInfo`
 *     (`logicalShape`, `texShape`, `isUniform`).
 * @returns GLSL source defining `float get<Name>(int row, int col)`.
 */
function getSampler2D(inputInfo) {
    var shape = inputInfo.shapeInfo.logicalShape;
    var texName = inputInfo.name;
    var getterName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
    var texShape = inputInfo.shapeInfo.texShape;
    var signature = "float " + getterName + "(int row, int col) {\n";
    if (texShape != null && tf.util.arraysEqual(shape, texShape)) {
        return "\n " + signature + " vec2 uv = (vec2(col, row) + halfCR) / vec2(" + texShape[1] + ".0, " + texShape[0] + ".0);\n return sampleTexture(" + texName + ", uv);\n }\n ";
    }
    var squeezed = tf.util.squeezeShape(shape);
    if (squeezed.newShape.length < shape.length) {
        var squeezedInfo = squeezeInputInfo(inputInfo, squeezed.newShape);
        var innerSampler = getSamplerFromInInfo(squeezedInfo);
        var forwardedArgs = getSqueezedParams(['row', 'col'], squeezed.keptDims);
        return "\n " + innerSampler + "\n " + signature + " return " + getterName + "(" + forwardedArgs + ");\n }\n ";
    }
    var logicalCols = shape[1];
    if (inputInfo.shapeInfo.isUniform) {
        // Uniform arrays will be less than 65505 (no risk of float16 overflow).
        return "\n " + signature + " int index = round(dot(vec2(row, col), vec2(" + logicalCols + ", 1)));\n " + getUniformSampler(inputInfo) + "\n }\n ";
    }
    var texRows = texShape[0];
    var texCols = texShape[1];
    var offsetUniform = getFlatOffsetUniformName(texName);
    if (texCols === 1) {
        // index is used directly as physical (no risk of float16 overflow).
        return "\n " + signature + " float index = dot(vec3(row, col, " + offsetUniform + "), vec3(" + logicalCols + ", 1, 1));\n vec2 uv = vec2(0.5, (index + 0.5) / " + texRows + ".0);\n return sampleTexture(" + texName + ", uv);\n }\n ";
    }
    if (texRows === 1) {
        // index is used directly as physical (no risk of float16 overflow).
        return "\n " + signature + " float index = dot(vec3(row, col, " + offsetUniform + "), vec3(" + logicalCols + ", 1, 1));\n vec2 uv = vec2((index + 0.5) / " + texCols + ".0, 0.5);\n return sampleTexture(" + texName + ", uv);\n }\n ";
    }
    return "\n " + signature + " // Explicitly use integer operations as dot() only works on floats.\n int index = row * " + logicalCols + " + col + " + offsetUniform + ";\n vec2 uv = uvFromFlat(" + texRows + ", " + texCols + ", index);\n return sampleTexture(" + texName + ", uv);\n }\n";
}
|
|
/**
 * Emits the packed getter for a rank-3 input,
 * `get<Name>(int b, int row, int col)`, returning the whole texel.
 *
 * A leading dimension of size 1 is dropped: the sampler for the remaining
 * shape is emitted together with a forwarding overload. Otherwise the UV is
 * computed with `packedUVfrom3D` from the packed texture size, the texels
 * per batch and the texels per logical row.
 *
 * @param inputInfo Input descriptor exposing `name` and `shapeInfo`
 *     (`logicalShape`, `texShape`).
 * @returns GLSL source defining `vec4 get<Name>(int b, int row, int col)`.
 */
function getPackedSampler3D(inputInfo) {
    var shape = inputInfo.shapeInfo.logicalShape;
    var texName = inputInfo.name;
    var getterName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
    var texShape = inputInfo.shapeInfo.texShape;
    var packedRows = Math.ceil(texShape[0] / 2);
    var packedCols = Math.ceil(texShape[1] / 2);
    var signature = "vec4 " + getterName + "(int b, int row, int col) {\n";
    if (shape[0] === 1) {
        var innerShape = shape.slice(1);
        var squeezedInfo = squeezeInputInfo(inputInfo, innerShape);
        var innerSampler = getPackedSamplerFromInInfo(squeezedInfo);
        // Only the row and col arguments survive the squeeze.
        var forwardedArgs = getSqueezedParams(['b', 'row', 'col'], [1, 2]);
        return "\n " + innerSampler + "\n " + signature + " return " + getterName + "(" + forwardedArgs + ");\n }\n ";
    }
    var texelsPerRow = Math.ceil(shape[2] / 2);
    var texelsPerBatch = texelsPerRow * Math.ceil(shape[1] / 2);
    var lookup = getGlslDifferences().texture2D;
    return "\n " + signature + " vec2 uv = packedUVfrom3D(\n " + packedRows + ", " + packedCols + ", " + texelsPerBatch + ", " + texelsPerRow + ", b, row, col);\n return " + lookup + "(" + texName + ", uv);\n }\n ";
}
|
|
/**
 * Emits the unpacked getter for a rank-3 input,
 * `get<Name>(int row, int col, int depth)`.
 *
 * Variants, in order of precedence:
 *  - shape has size-1 dimensions: emit the sampler for the squeezed shape
 *    and a forwarding overload;
 *  - uniform input: flat index into the uniform array;
 *  - texture width equals the row stride (and no flat offset): `row` is the
 *    texture row;
 *  - texture width equals the innermost dimension (and no flat offset):
 *    `depth` is the texture column;
 *  - general case: integer flat index converted with `uvFromFlat`.
 *
 * @param inputInfo Input descriptor exposing `name` and `shapeInfo`
 *     (`logicalShape`, `texShape`, `isUniform`, `flatOffset`).
 * @returns GLSL source defining `float get<Name>(int row, int col, int depth)`.
 */
function getSampler3D(inputInfo) {
    var shape = inputInfo.shapeInfo.logicalShape;
    var texName = inputInfo.name;
    var getterName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
    var signature = "float " + getterName + "(int row, int col, int depth) {\n";
    var rowStride = shape[1] * shape[2];
    var colStride = shape[2];
    var squeezed = tf.util.squeezeShape(shape);
    if (squeezed.newShape.length < shape.length) {
        var squeezedInfo = squeezeInputInfo(inputInfo, squeezed.newShape);
        var innerSampler = getSamplerFromInInfo(squeezedInfo);
        var forwardedArgs = getSqueezedParams(['row', 'col', 'depth'], squeezed.keptDims);
        return "\n " + innerSampler + "\n " + signature + " return " + getterName + "(" + forwardedArgs + ");\n }\n ";
    }
    if (inputInfo.shapeInfo.isUniform) {
        // Uniform arrays will be less than 65505 (no risk of float16 overflow).
        return "\n " + signature + " int index = round(dot(vec3(row, col, depth),\n vec3(" + rowStride + ", " + colStride + ", 1)));\n " + getUniformSampler(inputInfo) + "\n }\n ";
    }
    var texShape = inputInfo.shapeInfo.texShape;
    var texRows = texShape[0];
    var texCols = texShape[1];
    var hasNoFlatOffset = inputInfo.shapeInfo.flatOffset == null;
    if (texCols === rowStride && hasNoFlatOffset) {
        // texC is used directly as physical (no risk of float16 overflow).
        return "\n " + signature + " float texR = float(row);\n float texC = dot(vec2(col, depth), vec2(" + colStride + ", 1));\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(" + texCols + ".0, " + texRows + ".0);\n return sampleTexture(" + texName + ", uv);\n }\n ";
    }
    if (texCols === colStride && hasNoFlatOffset) {
        // texR is used directly as physical (no risk of float16 overflow).
        return "\n " + signature + " float texR = dot(vec2(row, col), vec2(" + shape[1] + ", 1));\n float texC = float(depth);\n vec2 uv = (vec2(texC, texR) + halfCR) / vec2(" + texCols + ".0, " + texRows + ".0);\n return sampleTexture(" + texName + ", uv);\n }\n ";
    }
    var offsetUniform = getFlatOffsetUniformName(texName);
    return "\n " + signature + " // Explicitly use integer operations as dot() only works on floats.\n int index = row * " + rowStride + " + col * " + colStride + " + depth + " + offsetUniform + ";\n vec2 uv = uvFromFlat(" + texRows + ", " + texCols + ", index);\n return sampleTexture(" + texName + ", uv);\n }\n ";
}
|
|
/**
 * Emits the packed getter for an input of arbitrary rank, returning the
 * whole texel.
 *
 * The generated function takes `b`, `row`, `col` plus one `bN` argument for
 * every additional outer dimension. The flat texel index is the sum of each
 * batch argument times its texel stride plus the packed (row / 2, col / 2)
 * position inside the innermost 2D slice.
 *
 * @param inputInfo Input descriptor exposing `name` and `shapeInfo`
 *     (`logicalShape`, `texShape`).
 * @returns GLSL source defining `vec4 get<Name>(...)`.
 */
function getPackedSamplerND(inputInfo) {
    var shape = inputInfo.shapeInfo.logicalShape;
    var rank = shape.length;
    var texName = inputInfo.name;
    var getterName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
    var texShape = inputInfo.shapeInfo.texShape;
    var packedRows = Math.ceil(texShape[0] / 2);
    var packedCols = Math.ceil(texShape[1] / 2);
    var texelsPerRow = Math.ceil(shape[rank - 1] / 2);
    // Starts as the texel count of one innermost 2D slice, then grows as
    // outer dimensions are folded in.
    var batchStride = texelsPerRow * Math.ceil(shape[rank - 2] / 2);
    var paramList = "int b, int row, int col";
    var flatIndexExpr = "b * " + batchStride + " + (row / 2) * " + texelsPerRow + " + (col / 2)";
    var b = 2;
    while (b < rank - 1) {
        paramList = "int b" + b + ", " + paramList;
        batchStride *= shape[rank - b - 1];
        flatIndexExpr = "b" + b + " * " + batchStride + " + " + flatIndexExpr;
        b++;
    }
    var lookup = getGlslDifferences().texture2D;
    return "\n vec4 " + getterName + "(" + paramList + ") {\n int index = " + flatIndexExpr + ";\n int texR = index / " + packedCols + ";\n int texC = index - texR * " + packedCols + ";\n vec2 uv = (vec2(texC, texR) + halfCR) / vec2(" + packedCols + ", " + packedRows + ");\n return " + lookup + "(" + texName + ", uv);\n }\n ";
}
|
|
/**
 * Emits the unpacked getter for a rank-4 input,
 * `get<Name>(int row, int col, int depth, int depth2)`.
 *
 * Variants, in order of precedence:
 *  - shape has size-1 dimensions: emit the sampler for the squeezed shape
 *    and a forwarding overload;
 *  - uniform input: flat index into the uniform array;
 *  - texture width equals the row stride (and no flat offset): `row` is the
 *    texture row;
 *  - texture width equals the innermost dimension (and no flat offset):
 *    `depth2` is the texture column;
 *  - general case: integer flat index converted with `uvFromFlat`.
 *
 * @param inputInfo Input descriptor exposing `name` and `shapeInfo`
 *     (`logicalShape`, `texShape`, `isUniform`, `flatOffset`).
 * @returns GLSL source defining the 4-argument `float get<Name>(...)`.
 */
function getSampler4D(inputInfo) {
    var shape = inputInfo.shapeInfo.logicalShape;
    var texName = inputInfo.name;
    var getterName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
    var signature = "float " + getterName + "(int row, int col, int depth, int depth2) {\n";
    var depthStride = shape[3];
    var colStride = shape[2] * depthStride;
    var rowStride = shape[1] * colStride;
    var squeezed = tf.util.squeezeShape(shape);
    if (squeezed.newShape.length < shape.length) {
        var squeezedInfo = squeezeInputInfo(inputInfo, squeezed.newShape);
        var innerSampler = getSamplerFromInInfo(squeezedInfo);
        var forwardedArgs = getSqueezedParams(['row', 'col', 'depth', 'depth2'], squeezed.keptDims);
        return "\n " + innerSampler + "\n " + signature + " return " + getterName + "(" + forwardedArgs + ");\n }\n ";
    }
    if (inputInfo.shapeInfo.isUniform) {
        // Uniform arrays will be less than 65505 (no risk of float16 overflow).
        return "\n " + signature + " int index = round(dot(vec4(row, col, depth, depth2),\n vec4(" + rowStride + ", " + colStride + ", " + depthStride + ", 1)));\n " + getUniformSampler(inputInfo) + "\n }\n ";
    }
    var hasNoFlatOffset = inputInfo.shapeInfo.flatOffset == null;
    var texShape = inputInfo.shapeInfo.texShape;
    var texRows = texShape[0];
    var texCols = texShape[1];
    if (texCols === rowStride && hasNoFlatOffset) {
        // texC is used directly as physical (no risk of float16 overflow).
        return "\n " + signature + " float texR = float(row);\n float texC =\n dot(vec3(col, depth, depth2),\n vec3(" + colStride + ", " + depthStride + ", 1));\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(" + texCols + ".0, " + texRows + ".0);\n return sampleTexture(" + texName + ", uv);\n }\n ";
    }
    if (texCols === depthStride && hasNoFlatOffset) {
        // texR is used directly as physical (no risk of float16 overflow).
        return "\n " + signature + " float texR = dot(vec3(row, col, depth),\n vec3(" + shape[1] * shape[2] + ", " + shape[2] + ", 1));\n float texC = float(depth2);\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(" + texCols + ".0, " + texRows + ".0);\n return sampleTexture(" + texName + ", uv);\n }\n ";
    }
    var offsetUniform = getFlatOffsetUniformName(texName);
    return "\n " + signature + " // Explicitly use integer operations as dot() only works on floats.\n int index = row * " + rowStride + " + col * " + colStride + " +\n depth * " + depthStride + " + depth2;\n vec2 uv = uvFromFlat(" + texRows + ", " + texCols + ", index + " + offsetUniform + ");\n return sampleTexture(" + texName + ", uv);\n }\n ";
}
|
|
/**
 * Emits the unpacked getter for a rank-5 input,
 * `get<Name>(int row, int col, int depth, int depth2, int depth3)`.
 *
 * Variants, in order of precedence:
 *  - shape has size-1 dimensions: emit the sampler for the squeezed shape
 *    and a forwarding overload;
 *  - uniform input: flat index into the uniform array;
 *  - texture width equals the row stride (and no flat offset): `row` is the
 *    texture row;
 *  - texture width equals the innermost dimension (and no flat offset):
 *    `depth3` is the texture column;
 *  - general case: integer flat index converted with `uvFromFlat`.
 *
 * @param inputInfo Input descriptor exposing `name` and `shapeInfo`
 *     (`logicalShape`, `texShape`, `isUniform`, `flatOffset`).
 * @returns GLSL source defining the 5-argument `float get<Name>(...)`.
 */
function getSampler5D(inputInfo) {
    var shape = inputInfo.shapeInfo.logicalShape;
    var texName = inputInfo.name;
    var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
    // Row-major strides of the logical shape, innermost first.
    var stride3 = shape[4];
    var stride2 = shape[3] * stride3;
    var stride1 = shape[2] * stride2;
    var stride0 = shape[1] * stride1;
    var _a = tf.util.squeezeShape(shape), newShape = _a.newShape, keptDims = _a.keptDims;
    if (newShape.length < shape.length) {
        var newInputInfo = squeezeInputInfo(inputInfo, newShape);
        var params = ['row', 'col', 'depth', 'depth2', 'depth3'];
        return "\n " + getSamplerFromInInfo(newInputInfo) + "\n float " + funcName + "(int row, int col, int depth, int depth2, int depth3) {\n return " + funcName + "(" + getSqueezedParams(params, keptDims) + ");\n }\n ";
    }
    if (inputInfo.shapeInfo.isUniform) {
        // Uniform arrays will be less than 65505 (no risk of float16 overflow).
        // The uniform lookup snippet is given an `int index` by every other
        // rank (see the 1D signature and the 2D/3D/4D/6D variants), and GLSL
        // ES has no implicit int -> float conversion, so the index is rounded
        // to int and `depth3` is converted explicitly before the addition.
        return "\n float " + funcName + "(int row, int col, int depth, int depth2, int depth3) {\n int index = round(dot(\n vec4(row, col, depth, depth2),\n vec4(" + stride0 + ", " + stride1 + ", " + stride2 + ", " + stride3 + ")) +\n float(depth3));\n " + getUniformSampler(inputInfo) + "\n }\n ";
    }
    var flatOffset = inputInfo.shapeInfo.flatOffset;
    var texShape = inputInfo.shapeInfo.texShape;
    var texNumR = texShape[0];
    var texNumC = texShape[1];
    if (texNumC === stride0 && flatOffset == null) {
        // texC is used directly as physical (no risk of float16 overflow).
        return "\n float " + funcName + "(int row, int col, int depth, int depth2, int depth3) {\n int texR = row;\n float texC = dot(vec4(col, depth, depth2, depth3),\n vec4(" + stride1 + ", " + stride2 + ", " + stride3 + ", 1));\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(" + texNumC + ".0, " + texNumR + ".0);\n return sampleTexture(" + texName + ", uv);\n }\n ";
    }
    if (texNumC === stride3 && flatOffset == null) {
        // texR is used directly as physical (no risk of float16 overflow).
        return "\n float " + funcName + "(int row, int col, int depth, int depth2, int depth3) {\n float texR = dot(\n vec4(row, col, depth, depth2),\n vec4(" + shape[1] * shape[2] * shape[3] + ",\n " + shape[2] * shape[3] + ", " + shape[3] + ", 1));\n int texC = depth3;\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(" + texNumC + ".0, " + texNumR + ".0);\n return sampleTexture(" + texName + ", uv);\n }\n ";
    }
    var offset = getFlatOffsetUniformName(texName);
    return "\n float " + funcName + "(int row, int col, int depth, int depth2, int depth3) {\n // Explicitly use integer operations as dot() only works on floats.\n int index = row * " + stride0 + " + col * " + stride1 + " + depth * " + stride2 + " +\n depth2 * " + stride3 + " + depth3 + " + offset + ";\n vec2 uv = uvFromFlat(" + texNumR + ", " + texNumC + ", index);\n return sampleTexture(" + texName + ", uv);\n }\n ";
}
|
|
/**
 * Generates the GLSL source of the sampler
 * `get<Name>(row, col, depth, depth2, depth3, depth4)` for a rank-6 input.
 *
 * One of five variants is emitted:
 *  - a wrapper that forwards to the sampler generated for the squeezed shape,
 *    when the logical shape has dimensions that can be squeezed away;
 *  - a lookup into a uniform array, when the input is uploaded as a uniform;
 *  - a direct texture lookup where the texture row is `row`, when the texture
 *    width equals the stride of the first dimension and there is no offset;
 *  - a direct texture lookup where the texture column is `depth4`, when the
 *    texture width equals the innermost dimension and there is no offset;
 *  - otherwise, a lookup through a flat integer index plus the offset uniform.
 *
 * @param inputInfo Object with a `name` and a `shapeInfo` holding
 *     `logicalShape`, `texShape`, `isUniform` and `flatOffset`.
 * @returns The GLSL source as a string.
 */
function getSampler6D(inputInfo) {
    const shapeInfo = inputInfo.shapeInfo;
    const shape = shapeInfo.logicalShape;
    const texName = inputInfo.name;
    const funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1);
    // Every variant opens with the same GLSL signature.
    const header = `\n float ${funcName}(int row, int col, int depth,\n int depth2, int depth3, int depth4) {\n `;

    const { newShape, keptDims } = tf.util.squeezeShape(shape);
    if (newShape.length < shape.length) {
        const squeezedInfo = squeezeInputInfo(inputInfo, newShape);
        const paramNames = ['row', 'col', 'depth', 'depth2', 'depth3', 'depth4'];
        const squeezedSampler = getSamplerFromInInfo(squeezedInfo);
        const forwardedArgs = getSqueezedParams(paramNames, keptDims);
        return `\n ${squeezedSampler}${header}return ${funcName}(${forwardedArgs});\n }\n `;
    }

    // Row-major strides of the five leading dimensions.
    const stride4 = shape[5];
    const stride3 = shape[4] * stride4;
    const stride2 = shape[3] * stride3;
    const stride1 = shape[2] * stride2;
    const stride0 = shape[1] * stride1;

    if (shapeInfo.isUniform) {
        // Uniform arrays will be less than 65505 (no risk of float16 overflow).
        return `${header}int index = round(dot(\n vec4(row, col, depth, depth2),\n vec4(${stride0}, ${stride1}, ${stride2}, ${stride3})) +\n dot(\n vec2(depth3, depth4),\n vec2(${stride4}, 1)));\n ${getUniformSampler(inputInfo)}\n }\n `;
    }

    const hasNoOffset = shapeInfo.flatOffset == null;
    const texShape = shapeInfo.texShape;
    const texNumR = texShape[0];
    const texNumC = texShape[1];
    // Shared ending of the two variants that address the texture directly.
    const directLookupTail = `vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(${texNumC}.0, ${texNumR}.0);\n return sampleTexture(${texName}, uv);\n }\n `;

    if (texNumC === stride0 && hasNoOffset) {
        // texC is used directly as physical (no risk of float16 overflow).
        return `${header}int texR = row;\n float texC = dot(vec4(col, depth, depth2, depth3),\n vec4(${stride1}, ${stride2}, ${stride3}, ${stride4})) +\n float(depth4);\n ${directLookupTail}`;
    }

    if (texNumC === stride4 && hasNoOffset) {
        // texR is used directly as physical (no risk of float16 overflow).
        return `${header}float texR = dot(vec4(row, col, depth, depth2),\n vec4(${shape[1] * shape[2] * shape[3] * shape[4]},\n ${shape[2] * shape[3] * shape[4]},\n ${shape[3] * shape[4]},\n ${shape[4]})) + float(depth3);\n int texC = depth4;\n ${directLookupTail}`;
    }

    const offset = getFlatOffsetUniformName(texName);
    return `${header}// Explicitly use integer operations as dot() only works on floats.\n int index = row * ${stride0} + col * ${stride1} + depth * ${stride2} +\n depth2 * ${stride3} + depth3 * ${stride4} + depth4 + ${offset};\n vec2 uv = uvFromFlat(${texNumR}, ${texNumC}, index);\n return sampleTexture(${texName}, uv);\n }\n `;
}
|
|
/**
 * Produces the GLSL statements that return the element of a uniform input
 * selected by the shader-local variable `index`.
 *
 * Inputs whose logical shape holds fewer than two elements are returned
 * directly by name; larger inputs are read with a loop that compares the loop
 * counter against `index`.
 *
 * @param inputInfo Object with a `name` and `shapeInfo.logicalShape`.
 * @returns The GLSL statements as a string.
 */
function getUniformSampler(inputInfo) {
    const uniformName = inputInfo.name;
    const elementCount = tf.util.sizeFromShape(inputInfo.shapeInfo.logicalShape);
    return elementCount < 2 ?
        `return ${uniformName};` :
        `\n for (int i = 0; i < ${elementCount}; i++) {\n if (i == index) {\n return ${uniformName}[i];\n }\n }\n `;
}
|
|
/**
 * Generates the GLSL function `get<Name>AtOutCoords()` for a packed input. The
 * emitted function fetches the output coordinates, zeroes the coordinates
 * along broadcast dimensions, calls the input's `get<Name>` sampler and
 * rearranges the sampled `vec4` according to how the input is broadcast.
 *
 * @param inputInfo Object with a `name` and `shapeInfo.logicalShape`.
 * @param outShapeInfo Shape info of the output (`logicalShape` is read).
 * @returns The GLSL source as a string.
 */
function getPackedSamplerAtOutputCoords(inputInfo, outShapeInfo) {
    const texName = inputInfo.name;
    const texFuncSnippet = texName.charAt(0).toUpperCase() + texName.slice(1);
    const funcName = `get${texFuncSnippet}AtOutCoords`;
    const inShape = inputInfo.shapeInfo.logicalShape;
    const outShape = outShapeInfo.logicalShape;
    const inRank = inShape.length;
    const outRank = outShape.length;
    const broadcastDims = getBroadcastDims(inShape, outShape);
    const type = getCoordsDataType(outRank);
    const rankDiff = outRank - inRank;
    const fields = ['x', 'y', 'z', 'w', 'u', 'v'];
    const outIsBelowRank2 = outRank < 2;

    // Statements that reset the coordinates along broadcast dimensions.
    let coordsSnippet;
    if (inRank === 0) {
        coordsSnippet = '';
    }
    else if (outIsBelowRank2 && broadcastDims.length >= 1) {
        coordsSnippet = 'coords = 0;';
    }
    else {
        coordsSnippet = broadcastDims
            .map((d) => `coords.${fields[d + rankDiff]} = 0;`)
            .join('\n');
    }

    // Argument list handed to the input's own sampler.
    const unpackedCoordsSnippet = (outIsBelowRank2 && inRank > 0) ?
        'coords' :
        inShape.map((_, i) => `coords.${fields[i + rankDiff]}`).join(', ');

    const isInputScalar = tf.util.sizeFromShape(inShape) === 1;
    const isOutputScalar = tf.util.sizeFromShape(outShape) === 1;

    // Chooses how the sampled vec4 is rearranged before it is returned.
    const pickReturnStatement = () => {
        if (inRank === 1 && !isInputScalar && !isOutputScalar) {
            return "\n return vec4(outputValue.xy, outputValue.xy);\n ";
        }
        if (isInputScalar && !isOutputScalar) {
            return outRank === 1 ?
                "\n return vec4(outputValue.x, outputValue.x, 0., 0.);\n " :
                "\n return vec4(outputValue.x);\n ";
        }
        if (broadcastDims.length) {
            const rowsAreBroadcast = broadcastDims.includes(inRank - 2);
            const colsAreBroadcast = broadcastDims.includes(inRank - 1);
            if (rowsAreBroadcast && colsAreBroadcast) {
                return "return vec4(outputValue.x);";
            }
            if (rowsAreBroadcast) {
                return "return vec4(outputValue.x, outputValue.y, outputValue.x, outputValue.y);";
            }
            if (colsAreBroadcast) {
                return "return vec4(outputValue.xx, outputValue.zz);";
            }
        }
        return "return outputValue;";
    };
    const output = pickReturnStatement();

    return `\n vec4 ${funcName}() {\n ${type} coords = getOutputCoords();\n ${coordsSnippet}\n vec4 outputValue = get${texFuncSnippet}(${unpackedCoordsSnippet});\n ${output}\n }\n `;
}
|
|
/**
 * Generates the GLSL function `get<Name>AtOutCoords()` for an unpacked input.
 *
 * When the input is a texture with no flat offset, has the same rank as the
 * output and the same texture shape, the emitted function samples the texture
 * at `resultUV`. Otherwise it fetches the output coordinates, zeroes the
 * coordinates along broadcast dimensions and calls the input's `get<Name>`
 * sampler.
 *
 * @param inputInfo Object with a `name` and a `shapeInfo`.
 * @param outShapeInfo Shape info of the output (`logicalShape`, `texShape`).
 * @returns The GLSL source as a string.
 */
function getSamplerAtOutputCoords(inputInfo, outShapeInfo) {
    const texName = inputInfo.name;
    const texFuncSnippet = texName.charAt(0).toUpperCase() + texName.slice(1);
    const funcName = `get${texFuncSnippet}AtOutCoords`;
    const inInfo = inputInfo.shapeInfo;
    const outTexShape = outShapeInfo.texShape;
    const inTexShape = inInfo.texShape;
    const inShape = inInfo.logicalShape;
    const outShape = outShapeInfo.logicalShape;
    const inRank = inShape.length;
    const outRank = outShape.length;

    const canSampleAtResultUV = !inInfo.isUniform && inRank === outRank &&
        inInfo.flatOffset == null &&
        tf.util.arraysEqual(inTexShape, outTexShape);
    if (canSampleAtResultUV) {
        return `\n float ${funcName}() {\n return sampleTexture(${texName}, resultUV);\n }\n `;
    }

    const type = getCoordsDataType(outRank);
    const broadcastDims = getBroadcastDims(inShape, outShape);
    const rankDiff = outRank - inRank;
    const fields = ['x', 'y', 'z', 'w', 'u', 'v'];
    const outIsBelowRank2 = outRank < 2;

    // Statements that reset the coordinates along broadcast dimensions.
    let coordsSnippet;
    if (inRank === 0) {
        coordsSnippet = '';
    }
    else if (outIsBelowRank2 && broadcastDims.length >= 1) {
        coordsSnippet = 'coords = 0;';
    }
    else {
        coordsSnippet = broadcastDims
            .map((d) => `coords.${fields[d + rankDiff]} = 0;`)
            .join('\n');
    }

    // Argument list handed to the input's own sampler.
    const unpackedCoordsSnippet = (outIsBelowRank2 && inRank > 0) ?
        'coords' :
        inShape.map((_, i) => `coords.${fields[i + rankDiff]}`).join(', ');

    return `\n float ${funcName}() {\n ${type} coords = getOutputCoords();\n ${coordsSnippet}\n return get${texFuncSnippet}(${unpackedCoordsSnippet});\n }\n `;
}
|
|
/**
 * Maps a tensor rank to the name of the GLSL type used for its coordinates.
 *
 * @param rank Rank of the tensor.
 * @returns `'int'` for ranks up to 1, `'ivec2'` … `'ivec6'` for ranks 2 to 6.
 * @throws Error for any other rank.
 */
function getCoordsDataType(rank) {
    if (rank <= 1) {
        return 'int';
    }
    switch (rank) {
        case 2:
            return 'ivec2';
        case 3:
            return 'ivec3';
        case 4:
            return 'ivec4';
        case 5:
            return 'ivec5';
        case 6:
            return 'ivec6';
        default:
            throw Error("GPU for rank " + rank + " is not yet supported");
    }
}
|
|
/**
 * Returns a copy of `inInfo` whose `shapeInfo.logicalShape` is replaced by
 * `squeezedShape`. The copy is made with a JSON round trip, so `inInfo` itself
 * is left untouched.
 */
function squeezeInputInfo(inInfo, squeezedShape) {
    // Deep copy through serialization.
    const serialized = JSON.stringify(inInfo);
    const squeezedInfo = JSON.parse(serialized);
    squeezedInfo.shapeInfo.logicalShape = squeezedShape;
    return squeezedInfo;
}
|
|
/**
 * Builds a comma-separated argument list containing only the parameter names
 * whose positions are listed in `keptDims`.
 *
 * @param params Parameter names, indexed by dimension.
 * @param keptDims Indices into `params` to keep, in output order.
 * @returns The selected names joined with `', '`.
 */
function getSqueezedParams(params, keptDims) {
    const selected = [];
    for (const dim of keptDims) {
        selected.push(params[dim]);
    }
    return selected.join(', ');
}
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2017 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Compiles a program into a binary description: builds shape infos for the
 * inputs and the output, generates the shader source with `makeShader`,
 * creates the WebGL program and looks up the uniform locations the runner
 * needs later.
 *
 * @param gpgpu Context providing `createProgram` and `getUniformLocation`.
 * @param program Program with `userCode`, `variableNames` and `packedInputs`.
 * @param inputs Input tensors (`shape`, `isUniform`, `texData`).
 * @param output Output tensor (`shape`, `texData`).
 * @returns Object holding the program, its source, the WebGL program, the
 *     uniform locations, the input/output shape infos and the locations of
 *     the `INFINITY` and `NAN` uniforms.
 */
function compileProgram(gpgpu, program, inputs, output) {
    const userCode = program.userCode;

    const inputInfos = inputs.map((input, index) => {
        const isUniform = input.isUniform;
        const texData = input.texData;
        const slice = texData != null ? texData.slice : null;
        // Only a strictly positive flat offset is recorded.
        const isSliced = slice != null && slice.flatOffset > 0;
        return {
            name: program.variableNames[index],
            shapeInfo: {
                logicalShape: input.shape,
                texShape: isUniform ? null : texData.texShape,
                isUniform: isUniform,
                isPacked: isUniform ? false : texData.isPacked,
                flatOffset: isSliced ? slice.flatOffset : null
            }
        };
    });
    const inShapeInfos = inputInfos.map((info) => info.shapeInfo);

    const outShapeInfo = {
        logicalShape: output.shape,
        texShape: output.texData.texShape,
        isUniform: false,
        isPacked: output.texData.isPacked,
        flatOffset: null
    };

    const source = makeShader(inputInfos, outShapeInfo, userCode, program.packedInputs);
    const webGLProgram = gpgpu.createProgram(source);

    // Special uniforms: NAN always, INFINITY only when WEBGL_VERSION is 1.
    const nanLoc = gpgpu.getUniformLocation(webGLProgram, 'NAN', false);
    const infLoc = tf.env().getNumber('WEBGL_VERSION') === 1 ?
        gpgpu.getUniformLocation(webGLProgram, 'INFINITY', false) :
        null;

    // User-defined uniforms: one location per variable plus its offset.
    const uniformLocations = {};
    for (const varName of program.variableNames) {
        const offsetName = `offset${varName}`;
        uniformLocations[varName] =
            gpgpu.getUniformLocation(webGLProgram, varName, false);
        uniformLocations[offsetName] =
            gpgpu.getUniformLocation(webGLProgram, offsetName, false);
    }

    return {
        program,
        source,
        webGLProgram,
        uniformLocations,
        inShapeInfos,
        outShapeInfo,
        infLoc,
        nanLoc,
    };
}
|
|
/**
 * Checks that the tensors a binary is about to run with match the shape infos
 * it was compiled with: same count, same logical shapes and, unless both
 * sides are uniforms, same texture shapes.
 *
 * @param shapeInfos Shape infos recorded at compile time.
 * @param inputs Tensors supplied at execution time.
 * @throws Error describing the first mismatch found.
 */
function validateBinaryAndProgram(shapeInfos, inputs) {
    if (shapeInfos.length !== inputs.length) {
        throw Error(`Binary was compiled with ${shapeInfos.length} inputs, but was executed with ${inputs.length} inputs`);
    }
    for (let i = 0; i < shapeInfos.length; i++) {
        const compiled = shapeInfos[i];
        const input = inputs[i];
        const compiledShape = compiled.logicalShape;
        const actualShape = input.shape;
        if (!tf.util.arraysEqual(compiledShape, actualShape)) {
            throw Error(`Binary was compiled with different shapes than the current args. Shapes ${compiledShape} and ${actualShape} must match`);
        }
        // The input is uploaded as uniform: there is no texture to compare.
        if (compiled.isUniform && input.isUniform) {
            continue;
        }
        const compiledTexShape = compiled.texShape;
        const actualTexShape = input.isUniform ? null : input.texData.texShape;
        if (!tf.util.arraysEqual(compiledTexShape, actualTexShape)) {
            throw Error(`Binary was compiled with different texture shapes than the current args. Shape ${compiledTexShape} and ${actualTexShape} must match`);
        }
    }
}
|
|
/**
 * Executes a compiled binary: validates inputs and output against the binary,
 * binds the output texture and the program, uploads the special and
 * user-defined uniforms, binds the input textures, runs the optional custom
 * setup hook and finally executes the program.
 *
 * @param gpgpu Context used to bind textures, set uniforms and execute.
 * @param binary Result of a previous compilation.
 * @param inputs Input tensors, in the order of `binary.program.variableNames`.
 * @param output Output tensor.
 * @param customSetup Optional callback invoked with `(gpgpu, webGLProgram)`
 *     right before execution.
 */
function runProgram(gpgpu, binary, inputs, output, customSetup) {
    validateBinaryAndProgram(binary.inShapeInfos, inputs);
    validateBinaryAndProgram([binary.outShapeInfo], [output]);

    const outTexData = output.texData;
    const outTex = outTexData.texture;
    const outTexShape = outTexData.texShape;
    if (outTexData.isPacked) {
        gpgpu.setOutputPackedMatrixTexture(outTex, outTexShape[0], outTexShape[1]);
    }
    else {
        gpgpu.setOutputMatrixTexture(outTex, outTexShape[0], outTexShape[1]);
    }
    gpgpu.setProgram(binary.webGLProgram);

    // Special uniforms: INFINITY only when WEBGL_VERSION is 1, NAN always.
    if (tf.env().getNumber('WEBGL_VERSION') === 1 && binary.infLoc !== null) {
        gpgpu.gl.uniform1f(binary.infLoc, Infinity);
    }
    if (binary.nanLoc !== null) {
        gpgpu.gl.uniform1f(binary.nanLoc, NaN);
    }

    // User-defined inputs.
    const variableNames = binary.program.variableNames;
    for (let i = 0; i < inputs.length; i++) {
        const input = inputs[i];
        const varName = variableNames[i];
        const varLoc = binary.uniformLocations[varName];
        const varOffsetLoc = binary.uniformLocations[`offset${varName}`];
        if (varLoc == null) {
            // The compiler inferred that this variable is not used in this shader.
            continue;
        }
        if (input.isUniform) {
            // Upload the values of the tensor as uniform.
            if (tf.util.sizeFromShape(input.shape) < 2) {
                gpgpu.gl.uniform1f(varLoc, input.uniformValues[0]);
            }
            else {
                const raw = input.uniformValues;
                const vals = raw instanceof Float32Array ? raw : new Float32Array(raw);
                gpgpu.gl.uniform1fv(varLoc, vals);
            }
            continue;
        }
        // If the input was sliced, upload the flat offset index.
        if (input.texData.slice != null && varOffsetLoc != null) {
            gpgpu.gl.uniform1i(varOffsetLoc, input.texData.slice.flatOffset);
        }
        gpgpu.setInputMatrixTexture(input.texData.texture, varLoc, i);
    }

    if (customSetup != null) {
        customSetup(gpgpu, binary.webGLProgram);
    }
    gpgpu.executeProgram();
}
|
|
/**
 * Builds a string key for a program applied to a given set of tensors. The key
 * combines the program's constructor name, one fragment per tensor (logical
 * shape, texture shape or `'uniform'`, and whether a positive flat offset is
 * present) and the program's user code.
 *
 * @param program Program instance with a `userCode` string.
 * @param inputs Input tensors.
 * @param output Output tensor; appended after the inputs.
 * @returns The key string.
 */
function makeShaderKey(program, inputs, output) {
    let keyInputs = '';
    for (const tensor of inputs.concat(output)) {
        const texData = tensor.texData;
        const hasOffset = texData != null && texData.slice != null &&
            texData.slice.flatOffset > 0;
        const texShape = tensor.isUniform ? 'uniform' : tensor.texData.texShape;
        keyInputs += `${tensor.shape}_${texShape}_${hasOffset}`;
    }
    // Fast string concat. See https://jsperf.com/string-concatenation/14.
    let key = program.constructor.name;
    key += '_' + keyInputs + '_' + program.userCode;
    return key;
}
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the License);
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an AS IS BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Computes the element-wise absolute value of `vals`.
 *
 * @param vals Indexable collection of numbers with a `length`.
 * @returns A new `Float32Array` of the same length.
 */
function simpleAbsImpl(vals) {
    return Float32Array.from({ length: vals.length }, (_, idx) => Math.abs(vals[idx]));
}
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Template that creates the implementation of an element-wise binary op with
 * broadcasting support.
 *
 * @param op Function `(a, b) => value` applied to each pair of elements.
 * @returns Function `(aShape, bShape, aVals, bVals, dtype)` that returns
 *     `[resultValues, resultShape]`.
 */
function createSimpleBinaryKernelImpl(op) {
    return (aShape, bShape, aVals, bVals, dtype) => {
        const outShape = tf.backend_util.assertAndGetBroadcastShape(aShape, bShape);
        const outRank = outShape.length;
        const outStrides = tf.util.computeStrides(outShape);
        const outSize = tf.util.sizeFromShape(outShape);
        const out = tf.util.getTypedArrayFromDType(dtype, outSize);
        const aRank = aShape.length;
        const bRank = bShape.length;
        const aStrides = tf.util.computeStrides(aShape);
        const bStrides = tf.util.computeStrides(bShape);
        const aBroadcastDims = tf.backend_util.getBroadcastDims(aShape, outShape);
        const bBroadcastDims = tf.backend_util.getBroadcastDims(bShape, outShape);

        if (aBroadcastDims.length + bBroadcastDims.length === 0) {
            // No broadcast dimensions: operands are read by flat index,
            // wrapping around each operand's own length.
            for (let i = 0; i < out.length; ++i) {
                out[i] = op(aVals[i % aVals.length], bVals[i % bVals.length]);
            }
            return [out, outShape];
        }

        // Maps a location in the result to a flat index into one operand by
        // keeping the trailing `rank` coordinates and zeroing broadcast dims.
        const operandIndex = (loc, rank, broadcastDims, strides) => {
            const operandLoc = loc.slice(-rank);
            for (const d of broadcastDims) {
                operandLoc[d] = 0;
            }
            return tf.util.locToIndex(operandLoc, rank, strides);
        };

        for (let i = 0; i < out.length; ++i) {
            const loc = tf.util.indexToLoc(i, outRank, outStrides);
            const aIndex = operandIndex(loc, aRank, aBroadcastDims, aStrides);
            const bIndex = operandIndex(loc, bRank, bBroadcastDims, bStrides);
            out[i] = op(aVals[aIndex], bVals[bIndex]);
        }
        return [out, outShape];
    };
}
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/** Element-wise addition, built from the shared binary kernel template. */
const addImpl = createSimpleBinaryKernelImpl((lhs, rhs) => lhs + rhs);
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Counts how often each value in `xVals` occurs, optionally weighting each
 * occurrence.
 *
 * @param xVals Values to count; each is used as an index into the result.
 * @param weightsVals Per-element weights, used when `weightsShape` is
 *     non-empty; otherwise every occurrence counts as 1.
 * @param weightsDtype Dtype of the returned array.
 * @param weightsShape Shape of the weights.
 * @param size Number of bins; values greater than or equal to it are ignored.
 * @returns Typed array of length `size` holding the (weighted) counts.
 * @throws Error if a value in `xVals` is negative.
 */
function bincountImpl(xVals, weightsVals, weightsDtype, weightsShape, size) {
    const hasWeights = tf.util.sizeFromShape(weightsShape) > 0;
    const bins = tf.util.makeZerosTypedArray(size, weightsDtype);
    for (let idx = 0; idx < xVals.length; idx++) {
        const bin = xVals[idx];
        if (bin < 0) {
            throw new Error('Input x must be non-negative!');
        }
        if (bin >= size) {
            continue;
        }
        bins[bin] += hasWeights ? weightsVals[idx] : 1;
    }
    return bins;
}
|
|
/**
 * Row-wise bincount over a 2-D buffer.
 *
 * @param xBuf Buffer of shape `[numRows, numCols]` holding the values to count.
 * @param weightsBuf Buffer of weights; when its `size` is 0 every occurrence
 *     counts as 1. Its dtype is used for the result.
 * @param size Number of bins per row; values greater than or equal to it are
 *     ignored.
 * @param binaryOutput When true, a bin is set to 1 if its value occurs at all
 *     instead of accumulating counts.
 * @returns Buffer of shape `[numRows, size]`.
 * @throws Error if a value in `xBuf` is negative.
 */
function bincountReduceImpl(xBuf, weightsBuf, size, binaryOutput = false) {
    const [numRows, numCols] = [xBuf.shape[0], xBuf.shape[1]];
    const outBuf = tf.buffer([numRows, size], weightsBuf.dtype);
    for (let row = 0; row < numRows; row++) {
        for (let col = 0; col < numCols; col++) {
            const bin = xBuf.get(row, col);
            if (bin < 0) {
                throw new Error('Input x must be non-negative!');
            }
            if (bin >= size) {
                continue;
            }
            if (binaryOutput) {
                outBuf.set(1, row, bin);
                continue;
            }
            const weighted = weightsBuf.size > 0;
            const current = outBuf.get(row, bin);
            const increment = weighted ? weightsBuf.get(row, col) : 1;
            outBuf.set(current + increment, row, bin);
        }
    }
    return outBuf;
}
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Template that creates the implementation of an element-wise unary op.
 *
 * @param op Function `(value, attrs) => result` applied to every element.
 * @returns Function `(values, dtype, attrs)` that returns a new typed array of
 *     the given dtype with `op` applied to each element of `values`.
 */
function createSimpleUnaryImpl(op) {
    return (values, dtype, attrs) => {
        const mapped = tf.util.getTypedArrayFromDType(dtype, values.length);
        let position = 0;
        while (position < values.length) {
            mapped[position] = op(values[position], attrs);
            ++position;
        }
        return mapped;
    };
}
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the License);
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an AS IS BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/** Element-wise `Math.ceil`, built from the shared unary kernel template. */
const ceilImpl = createSimpleUnaryImpl((value) => Math.ceil(value));
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Concatenates the given inputs into one flat output array.
 *
 * @param inputs Objects with `vals` (the data) and `shape`.
 * @param outShape Shape of the result; in the 2-D path `outShape[1]` is the
 *     row width of the output.
 * @param dtype Dtype of the result. String inputs are decoded from their byte
 *     representation first.
 * @param simplyConcat When true (and the dtype is not `'string'`) the inputs
 *     are copied back to back; otherwise each input is treated as a 2-D block
 *     and the blocks are placed side by side along the columns.
 * @returns The output values.
 */
function concatImpl(inputs, outShape, dtype, simplyConcat) {
    const outVals = tf.util.getArrayFromDType(dtype, tf.util.sizeFromShape(outShape));
    const isString = dtype === 'string';

    if (simplyConcat && !isString) {
        // Use built-in TypedArray.set() method for speed.
        let writePos = 0;
        for (const input of inputs) {
            const inputSize = tf.util.sizeFromShape(input.shape);
            outVals.set(input.vals, writePos);
            writePos += inputSize;
        }
        return outVals;
    }

    let colOffset = 0;
    for (const input of inputs) {
        const data = isString ?
            tf.backend_util.fromUint8ToStringArray(input.vals) :
            input.vals;
        const rows = input.shape[0];
        const cols = input.shape[1];
        let readPos = 0;
        for (let row = 0; row < rows; ++row) {
            const rowStart = row * outShape[1] + colOffset;
            for (let col = 0; col < cols; ++col) {
                outVals[rowStart + col] = data[readPos++];
            }
        }
        colOffset += cols;
    }
    return outVals;
}
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/** Element-wise strict equality (1 when equal, 0 otherwise), with broadcast. */
const equalImpl = createSimpleBinaryKernelImpl((lhs, rhs) => Number(lhs === rhs));
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the License);
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an AS IS BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/** Element-wise `Math.exp`, built from the shared unary kernel template. */
const expImpl = createSimpleUnaryImpl((value) => Math.exp(value));
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the License);
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an AS IS BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/** Element-wise `Math.expm1`, built from the shared unary kernel template. */
const expm1Impl = createSimpleUnaryImpl((value) => Math.expm1(value));
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the License);
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an AS IS BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/** Element-wise `Math.floor`, built from the shared unary kernel template. */
const floorImpl = createSimpleUnaryImpl((value) => Math.floor(value));
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2021 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Gathers slices from `paramsBuf` at the positions listed in `indicesData`.
 *
 * @param indicesData Flat array holding `sliceRank` index components for each
 *     of the `numSlices` slices.
 * @param paramsBuf Buffer to read from (`get` and `indexToLoc` are used).
 * @param dtype Dtype of the result buffer.
 * @param numSlices Number of slices to gather.
 * @param sliceRank Number of index components per slice.
 * @param sliceSize Number of elements copied per slice.
 * @param strides Multipliers applied to the index components to obtain the
 *     flat slice index.
 * @param paramsShape Shape of the params; only used in the error message.
 * @param paramsSize Total number of elements in the params.
 * @returns Buffer of shape `[numSlices, sliceSize]`.
 * @throws Error if a flat slice index is negative or not below
 *     `paramsSize / sliceSize`.
 */
function gatherNdImpl(indicesData, paramsBuf, dtype, numSlices, sliceRank, sliceSize, strides, paramsShape, paramsSize) {
    const outBuf = tf.buffer([numSlices, sliceSize], dtype);
    for (let slice = 0; slice < numSlices; slice++) {
        // Index components of this slice.
        const index = Array.from({ length: sliceRank }, (_, j) => indicesData[slice * sliceRank + j]);
        const flattenIndex = index.reduce((sum, dim, j) => sum + dim * strides[j], 0);
        if (flattenIndex < 0 || flattenIndex >= paramsSize / sliceSize) {
            throw new Error(`Invalid indices: ${index} does not index into ${paramsShape}`);
        }
        const sourceStart = flattenIndex * sliceSize;
        const targetStart = slice * sliceSize;
        for (let k = 0; k < sliceSize; k++) {
            const sourceLoc = paramsBuf.indexToLoc(sourceStart + k);
            outBuf.values[targetStart + k] = paramsBuf.get(...sourceLoc);
        }
    }
    return outBuf;
}
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Gathers elements of `xBuf` along its third axis. For every location in the
 * output, the coordinate at position 2 is replaced by the value stored in
 * `indicesBuf` at `[loc[0], loc[2]]`, and the element of `xBuf` at the
 * resulting location is copied.
 *
 * @param xBuf Source buffer (`dtype`, `values` and `locToIndex` are used).
 * @param indicesBuf Buffer of indices, addressed with two coordinates.
 * @param flattenOutputShape Shape of the output buffer.
 * @returns Buffer of shape `flattenOutputShape` with the dtype of `xBuf`.
 */
function gatherV2Impl(xBuf, indicesBuf, flattenOutputShape) {
    const outBuf = tf.buffer(flattenOutputShape, xBuf.dtype);
    for (let flat = 0; flat < outBuf.size; ++flat) {
        // Work on a copy so the location returned by the buffer is not altered.
        const sourceLoc = outBuf.indexToLoc(flat).slice();
        const lookupIndex = indicesBuf.locToIndex([sourceLoc[0], sourceLoc[2]]);
        sourceLoc[2] = indicesBuf.values[lookupIndex];
        outBuf.values[flat] = xBuf.values[xBuf.locToIndex(sourceLoc)];
    }
    return outBuf;
}
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/** Element-wise `a > b` (1 when true, 0 otherwise), with broadcast. */
const greaterImpl = createSimpleBinaryKernelImpl((lhs, rhs) => Number(lhs > rhs));
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// Kernel built from the scalar op `a >= b`, which yields 1 when true and 0
// otherwise.
const greaterEqualImpl = createSimpleBinaryKernelImpl((lhs, rhs) => {
    return lhs >= rhs ? 1 : 0;
});
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// Kernel built from the scalar op `a < b`, which yields 1 when true and 0
// otherwise.
const lessImpl = createSimpleBinaryKernelImpl((lhs, rhs) => {
    return lhs < rhs ? 1 : 0;
});
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// Kernel built from the scalar op `a <= b`, which yields 1 when true and 0
// otherwise.
const lessEqualImpl = createSimpleBinaryKernelImpl((lhs, rhs) => {
    return lhs <= rhs ? 1 : 0;
});
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Builds a float32 sequence of `num` values that starts at `start` and grows
 * by a constant increment of `(stop - start) / (num - 1)`.
 *
 * Each element is derived from the previous (already stored) element, so the
 * result carries the rounding of the float32 buffer it is written into.
 *
 * @param start First value of the sequence.
 * @param stop Value the sequence is stepped towards.
 * @param num Number of elements to allocate.
 * @returns The typed array produced by `tf.util.makeZerosTypedArray`.
 */
function linSpaceImpl(start, stop, num) {
    const increment = (stop - start) / (num - 1);
    const values = tf.util.makeZerosTypedArray(num, 'float32');
    values[0] = start;
    let position = 1;
    while (position < values.length) {
        values[position] = values[position - 1] + increment;
        position += 1;
    }
    return values;
}
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// Kernel built from the scalar op `Math.log(x)` (natural logarithm).
const logImpl = createSimpleUnaryImpl((value) => {
    return Math.log(value);
});
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Reduces consecutive runs of `reduceSize` values from `aVals` to their
 * maximum, writing one result per output element.
 *
 * A NaN inside a run wins: once NaN has been selected no later value can
 * replace it, because every comparison against NaN is false.
 *
 * @param aVals Flat input values; run `i` occupies
 *     `[i * reduceSize, (i + 1) * reduceSize)`.
 * @param reduceSize Number of input values folded into each output value.
 * @param outShape Shape of the output; its element count sets the number of
 *     runs that are reduced.
 * @param dtype dtype used to allocate the output typed array.
 * @returns The typed array holding one maximum per run.
 */
function maxImpl(aVals, reduceSize, outShape, dtype) {
    const outSize = tf.util.sizeFromShape(outShape);
    const vals = tf.util.getTypedArrayFromDType(dtype, outSize);
    for (let outIdx = 0; outIdx < vals.length; outIdx++) {
        const runStart = outIdx * reduceSize;
        const runEnd = runStart + reduceSize;
        let best = aVals[runStart];
        for (let k = runStart; k < runEnd; k++) {
            const candidate = aVals[k];
            // NaN never compares greater, so it has to be detected explicitly.
            const replacesBest = Number.isNaN(candidate) || candidate > best;
            if (replacesBest) {
                best = candidate;
            }
        }
        vals[outIdx] = best;
    }
    return vals;
}
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// Kernel built from the scalar op `Math.max(a, b)`.
const maximumImpl = createSimpleBinaryKernelImpl((lhs, rhs) => {
    return Math.max(lhs, rhs);
});
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// Kernel built from the scalar op `Math.min(a, b)`.
const minimumImpl = createSimpleBinaryKernelImpl((lhs, rhs) => {
    return Math.min(lhs, rhs);
});
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// Kernel built from the scalar op `a * b`.
const multiplyImpl = createSimpleBinaryKernelImpl((lhs, rhs) => {
    return lhs * rhs;
});
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Negates `xVals` by delegating to `multiplyImpl` with a scalar -1 operand.
 *
 * @param xVals Values to negate.
 * @param xShape Shape of `xVals`.
 * @param xDtype dtype of `xVals`; also used to create the -1 scalar.
 * @returns Whatever `multiplyImpl` returns for `(-1) * x`.
 */
function negImpl(xVals, xShape, xDtype) {
    const negativeOne = tf.util.createScalarValue(-1, xDtype);
    // The -1 operand is passed first, with an empty (scalar) shape.
    const scalarShape = [];
    return multiplyImpl(scalarShape, xShape, negativeOne, xVals, xDtype);
}
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// Kernel built from the scalar op `a !== b`, which yields 1 when the operands
// differ and 0 otherwise.
const notEqualImpl = createSimpleBinaryKernelImpl((lhs, rhs) => {
    return lhs !== rhs ? 1 : 0;
});
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Permutes the axes of a flat, row-major tensor.
 *
 * Every source element is visited once; its coordinates are reordered so that
 * output axis `d` takes the coordinate of source axis `perm[d]`, and the value
 * is written at the matching flat position of the output.
 *
 * @param xVals Flat source values.
 * @param xShape Shape of the source.
 * @param dtype dtype used to allocate the output typed array.
 * @param perm Axis permutation: output axis `d` comes from source axis
 *     `perm[d]`.
 * @param newShape Shape of the output.
 * @returns The typed array holding the transposed values.
 */
function transposeImpl(xVals, xShape, dtype, perm, newShape) {
    const rank = xShape.length;
    const elementCount = tf.util.sizeFromShape(xShape);
    const sourceStrides = tf.util.computeStrides(xShape);
    const targetStrides = tf.util.computeStrides(newShape);
    const result = tf.util.getTypedArrayFromDType(dtype, tf.util.sizeFromShape(newShape));
    for (let flat = 0; flat < elementCount; flat++) {
        const sourceLoc = tf.util.indexToLoc(flat, rank, sourceStrides);
        // Reorder the coordinates according to the permutation.
        const targetLoc = Array.from(sourceLoc, (_, axis) => sourceLoc[perm[axis]]);
        const targetIndex = tf.util.locToIndex(targetLoc, rank, targetStrides);
        result[targetIndex] = xVals[flat];
    }
    return result;
}
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Multiplies together consecutive runs of values from `xVals`, one run per
 * output element.
 *
 * The output and reduce shapes come from
 * `tf.backend_util.computeOutAndReduceShapes`; the output dtype is `xDtype`
 * upcast with 'int32'.
 *
 * @param xShape Shape of the input.
 * @param xDtype dtype of the input.
 * @param xVals Flat input values; run `i` occupies
 *     `[i * reduceSize, (i + 1) * reduceSize)`.
 * @param reductionAxes Axes being reduced.
 * @returns `{ outVals, outShape, outDtype }`.
 */
function prodImpl(xShape, xDtype, xVals, reductionAxes) {
    const [outShape, reduceShape] = tf.backend_util.computeOutAndReduceShapes(xShape, reductionAxes);
    const outDtype = tf.upcastType(xDtype, 'int32');
    const outSize = tf.util.sizeFromShape(outShape);
    const outVals = tf.util.makeZerosTypedArray(outSize, outDtype);
    const reduceSize = tf.util.sizeFromShape(reduceShape);
    for (let outIdx = 0; outIdx < outVals.length; outIdx++) {
        const runStart = outIdx * reduceSize;
        const runEnd = runStart + reduceSize;
        let running = 1;
        for (let k = runStart; k < runEnd; k++) {
            running *= xVals[k];
        }
        outVals[outIdx] = running;
    }
    return { outVals, outShape, outDtype };
}
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Builds the sequence `start, start + step, start + 2 * step, ...` that stops
 * before reaching `stop`.
 *
 * Special cases:
 *  - `start === stop` yields an empty array.
 *  - A step that points away from `stop` yields an empty array. The single
 *    exception is `step === 1` on a decreasing range, which is treated as
 *    "step not set" and replaced by -1.
 *
 * Each element is derived from the previous (already stored) element, so the
 * result carries the rounding of the typed array it is written into.
 *
 * NOTE(review): a `step` of 0 is not handled here; presumably the caller
 * rejects it before this function is reached.
 *
 * @param start First value of the sequence.
 * @param stop Exclusive end of the sequence.
 * @param step Increment between consecutive values.
 * @param dtype dtype used to allocate the output typed array.
 * @returns The typed array produced by `tf.util.makeZerosTypedArray`.
 */
function rangeImpl(start, stop, step, dtype) {
    const sameStartStop = start === stop;
    const increasingRangeNegativeStep = start < stop && step < 0;
    // Any positive step points the wrong way on a decreasing range. A step of
    // exactly 1 is excluded because it is auto-flipped to -1 below.
    const decreasingRangePositiveStep = stop < start && step > 0 && step !== 1;
    if (sameStartStop || increasingRangeNegativeStep ||
        decreasingRangePositiveStep) {
        return tf.util.makeZerosTypedArray(0, dtype);
    }
    // Auto adjust the step's sign if it hasn't been set (or was set to 1).
    const effectiveStep = (stop < start && step === 1) ? -1 : step;
    // Take the magnitude before rounding up: rounding a negative quotient up
    // first would drop the final element when the bounds are fractional.
    const numElements = Math.ceil(Math.abs((stop - start) / effectiveStep));
    const values = tf.util.makeZerosTypedArray(numElements, dtype);
    values[0] = start;
    for (let i = 1; i < values.length; i++) {
        values[i] = values[i - 1] + effectiveStep;
    }
    return values;
}
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// Kernel built from the scalar op `1 / Math.sqrt(x)` (reciprocal square root).
const rsqrtImpl = createSimpleUnaryImpl((value) => {
    return 1 / Math.sqrt(value);
});
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Extracts the region of `size` elements starting at `begin` from a flat
 * tensor of the given `shape`.
 *
 * When `tf.slice_util.isSliceContinous` reports the region as one contiguous
 * run, the result is taken directly from `vals` (`slice` for strings,
 * `subarray` for everything else). Otherwise the region is copied element by
 * element through tensor buffers; string data is decoded before the copy and
 * re-encoded afterwards.
 *
 * @param vals Flat source values.
 * @param begin Start coordinate of the region, one entry per axis.
 * @param size Extent of the region, one entry per axis.
 * @param shape Shape of the source.
 * @param dtype dtype of the source.
 * @returns The flat values of the extracted region.
 */
function sliceImpl(vals, begin, size, shape, dtype) {
    const isString = dtype === 'string';
    const contiguous = tf.slice_util.isSliceContinous(shape, begin, size);
    const length = tf.util.sizeFromShape(size);
    const xStrides = tf.util.computeStrides(shape);
    if (contiguous) {
        const from = tf.slice_util.computeFlatOffset(begin, xStrides);
        const to = from + length;
        return isString ? vals.slice(from, to) : vals.subarray(from, to);
    }
    const sourceData = isString ? tf.backend_util.fromUint8ToStringArray(vals) : vals;
    const inBuf = tf.buffer(shape, dtype, sourceData);
    const outBuf = tf.buffer(size, dtype);
    for (let flat = 0; flat < outBuf.size; flat++) {
        const outLoc = outBuf.indexToLoc(flat);
        // Shift the output coordinate by `begin` to find the source element.
        const inLoc = outLoc.map((coord, axis) => coord + begin[axis]);
        outBuf.set(inBuf.get(...inLoc), ...outLoc);
    }
    return isString ?
        tf.backend_util.fromStringArrayToUint8(outBuf.values) :
        outBuf.values;
}
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2021 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Gives every empty row of a sparse tensor exactly one entry holding
 * `defaultValue`, and returns the entries grouped by row.
 *
 * If no row is empty and the input rows are already in non-decreasing order,
 * the input `indices` and `values` arrays are returned as-is (not copied).
 *
 * @param indices Flat row-major index buffer of shape `[N, rank]`; the first
 *     component of each entry is its dense row.
 * @param indicesShape `[N, rank]`.
 * @param indicesDType dtype used to allocate the output index buffer.
 * @param values The N values, one per index entry.
 * @param valuesDType dtype used to allocate the output value buffer.
 * @param denseShape Dense shape of the sparse tensor; only `denseShape[0]`
 *     (the number of rows) is read.
 * @param defaultValue Value stored for each row that had no entries.
 * @returns `[outputIndices, outputIndicesShape, outputValues,
 *     emptyRowIndicator, reverseIndexMap]`, where `emptyRowIndicator[row]` is
 *     true for rows that were empty and `reverseIndexMap[i]` is the output
 *     position of input entry `i`.
 * @throws {Error} If `denseShape[0]` is 0 while there are index entries, or if
 *     an entry's row is negative or not less than `denseShape[0]`.
 */
function sparseFillEmptyRowsImpl(indices, indicesShape, indicesDType, values, valuesDType, denseShape, defaultValue) {
    const indicesCount = indicesShape[0];
    const denseRows = denseShape[0];
    const emptyRowIndicator = new Array(denseRows);
    const reverseIndexMap = new Array(indicesCount);
    const rank = indicesShape[1];
    if (denseRows === 0) {
        if (indicesCount !== 0) {
            // Kept on one logical line so the message carries no embedded
            // newline or source indentation.
            throw new Error('Received SparseTensor with denseShape[0] = 0 but ' +
                `indices.shape[0] = ${indicesCount}`);
        }
        const outputIndices = tf.util.getArrayFromDType(indicesDType, 0);
        const outputValues = tf.util.getArrayFromDType(valuesDType, 0);
        return [
            outputIndices, [0, rank], outputValues, emptyRowIndicator, reverseIndexMap
        ];
    }
    let rowsAreOrdered = true;
    let lastIndicesRow = 0;
    const csrOffset = new Array(denseRows).fill(0);
    for (let i = 0; i < indicesCount; ++i) {
        // indices is a 2d tensor with shape of [N, rank]
        const row = indices[i * rank];
        if (row < 0) {
            throw new Error(`indices(${i}, 0) is invalid: ${row} < 0`);
        }
        if (row >= denseRows) {
            throw new Error(`indices(${i}, 0) is invalid: ${row} >= ${denseRows}`);
        }
        ++csrOffset[row];
        rowsAreOrdered = rowsAreOrdered && (row >= lastIndicesRow);
        lastIndicesRow = row;
    }
    let allRowsFull = true;
    for (let row = 0; row < denseRows; ++row) {
        // csrOffset here describes the number of elements in this dense row
        const rowEmpty = (csrOffset[row] === 0);
        emptyRowIndicator[row] = rowEmpty;
        allRowsFull = allRowsFull && !rowEmpty;
        // In the filled version, each row has at least one element.
        csrOffset[row] = Math.max(csrOffset[row], 1);
        // Turn the per-row counts into a running total, so that afterwards
        // csrOffset[row] is the output position where row + 1 starts:
        //  csrOffset[0] == #{elements of row 0}
        //  csrOffset[1] == #{elements of row 1} + #{elements of row 0}
        //  ..
        if (row > 0) {
            csrOffset[row] += csrOffset[row - 1];
        }
    }
    if (allRowsFull && rowsAreOrdered) {
        // Nothing to fill or reorder: hand back the inputs unchanged.
        const outputIndices = indices;
        const outputValues = values;
        for (let i = 0; i < indicesCount; ++i) {
            reverseIndexMap[i] = i;
        }
        return [
            outputIndices, [indicesCount, rank], outputValues, emptyRowIndicator,
            reverseIndexMap
        ];
    }
    else {
        const fullIndicesCount = csrOffset[denseRows - 1];
        const outputIndices = tf.util.getArrayFromDType(indicesDType, fullIndicesCount * rank);
        const outputValues = tf.util.getArrayFromDType(valuesDType, fullIndicesCount);
        const filledCount = new Array(denseRows).fill(0);
        // Fill in values for rows that are not missing
        for (let i = 0; i < indicesCount; ++i) {
            // indices is a 2d tensor with shape of [N, rank]
            const row = indices[i * rank];
            const offset = filledCount[row];
            const outputI = ((row === 0) ? 0 : csrOffset[row - 1]) + offset;
            filledCount[row]++; // Increment the filled count for this row.
            for (let j = 0; j < rank; ++j) {
                // indices and outputIndices are 2d tensors with shape of [N, rank]
                outputIndices[outputI * rank + j] = indices[i * rank + j];
            }
            outputValues[outputI] = values[i];
            // We'll need this reverse index map to backprop correctly.
            reverseIndexMap[i] = outputI;
        }
        // Fill in values for rows that are missing
        for (let row = 0; row < denseRows; ++row) {
            const rowCount = filledCount[row];
            if (rowCount === 0) { // We haven't filled this row
                const startingIndex = (row === 0) ? 0 : csrOffset[row - 1];
                // The inserted entry sits at [row, 0, ..., 0].
                // outputIndices is a 2d tensor with shape of [N, rank]
                outputIndices[startingIndex * rank + 0] = row;
                for (let col = 1; col < rank; ++col) {
                    outputIndices[startingIndex * rank + col] = 0;
                }
                outputValues[startingIndex] = defaultValue;
            }
        }
        return [
            outputIndices, [fullIndicesCount, rank], outputValues, emptyRowIndicator,
            reverseIndexMap
        ];
    }
}
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2021 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Recomputes the indices of a sparse tensor for a new dense shape that has
 * the same number of dense elements.
 *
 * At most one entry of `targetShape` may be -1; that dimension is inferred
 * from the dense size of `inputShape`. Each input index is flattened with the
 * row-major strides of `inputShape` and expanded again with the row-major
 * strides of the resolved output shape.
 *
 * @param inputIndices Flat row-major index buffer of shape `[nnz, inputRank]`.
 * @param inputIndicesShape Shape of `inputIndices`; only element 0 (`nnz`) is
 *     read.
 * @param inputDType dtype used to allocate the output index buffer.
 * @param inputShape Dense shape the input indices refer to.
 * @param targetShape Requested dense shape, optionally with a single -1.
 * @returns `[newIndices, [nnz, outputRank], outputShape]`.
 * @throws {Error} If more than one dimension is -1, a dimension is negative
 *     (other than -1), the missing dimension cannot be inferred, or the
 *     element counts of the input and output shapes differ.
 */
function sparseReshapeImpl(inputIndices, inputIndicesShape, inputDType, inputShape, targetShape) {
    const denseSize = tf.util.sizeFromShape(inputShape);
    const nnz = inputIndicesShape[0];
    const outputRank = targetShape.length;
    // Compute the output shape. Determine product of specified dimensions, and
    // find the index of the unspecified one.
    const outputShape = [];
    let product = 1;
    let unknownIndex = -1;
    for (let d = 0; d < outputRank; ++d) {
        const size = targetShape[d];
        if (size === -1) {
            if (unknownIndex !== -1) {
                throw new Error(`only one output dimension may be -1, not both ${unknownIndex} and ${d}`);
            }
            unknownIndex = d;
            // Placeholder; replaced once the missing size has been inferred.
            outputShape.push(1);
        }
        else {
            if (size < 0) {
                throw new Error(`size ${d} must be non-negative, not ${size}`);
            }
            product *= size;
            outputShape.push(size);
        }
    }
    if (unknownIndex !== -1) {
        if (product <= 0) {
            throw new Error('reshape cannot infer the missing ' +
                'input size for an empty tensor unless all ' +
                'specified input sizes are non-zero');
        }
        const missing = Math.trunc(denseSize / product);
        if (product * missing !== denseSize) {
            // Kept on one logical line so the message carries no embedded
            // newline or source indentation.
            throw new Error(`Input to reshape is a SparseTensor with ${denseSize} ` +
                `dense values, but the requested shape requires a multiple of ${product}. ` +
                `inputShape=${inputShape} outputShape=${outputShape}`);
        }
        outputShape[unknownIndex] = missing;
    }
    const outputSize = tf.util.sizeFromShape(outputShape);
    if (outputSize !== denseSize) {
        throw new Error(`Input to reshape is a tensor with ${denseSize} dense values, but the requested shape has ${outputSize}. inputShape=${inputShape} outputShape=${outputShape}`);
    }
    // Row-major strides: the last axis has stride 1.
    const inputRank = inputShape.length;
    const inputStrides = [];
    if (inputRank > 0) {
        inputStrides[inputRank - 1] = 1;
        for (let d = inputRank - 2; d >= 0; --d) {
            inputStrides[d] = inputStrides[d + 1] * inputShape[d + 1];
        }
    }
    const outputStrides = [];
    if (outputRank > 0) {
        outputStrides[outputRank - 1] = 1;
        for (let d = outputRank - 2; d >= 0; --d) {
            outputStrides[d] = outputStrides[d + 1] * outputShape[d + 1];
        }
    }
    const newIndices = tf.util.getArrayFromDType(inputDType, nnz * outputRank);
    for (let i = 0; i < nnz; ++i) {
        // Flatten the input coordinate into a single row-major offset.
        let id = 0;
        for (let j = 0; j < inputRank; ++j) {
            // inputIndices is a 2d tensor with shape of [nnz, inputRank]
            id += inputIndices[i * inputRank + j] * inputStrides[j];
        }
        // Expand the offset into a coordinate of the output shape.
        for (let j = 0; j < outputRank; ++j) {
            // newIndices is a 2d tensor with shape of [nnz, outputRank]
            newIndices[i * outputRank + j] = Math.trunc(id / outputStrides[j]);
            id %= outputStrides[j];
        }
    }
    return [newIndices, [nnz, outputRank], outputShape];
}
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2021 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Sums (or averages) selected rows of `input` into output rows chosen by
 * `segmentIds`.
 *
 * `input` is viewed as a 2-D matrix whose row count is `inputShape[0]`.
 * Entry `i` adds input row `indices[i]` to output row `segmentIds[i]`. The
 * segment ids are expected to be sorted: the last id determines the number of
 * output rows, and a decreasing id raises an error. Output rows that receive
 * no input are set to `defaultValue`.
 *
 * @param input Flat input values.
 * @param inputShape Shape of the input.
 * @param inputDType dtype used to allocate the output buffer.
 * @param indices Input row selected by each entry.
 * @param segmentIds Output row targeted by each entry.
 * @param isMean When true, each output row is divided by the number of
 *     entries that contributed to it.
 * @param defaultValue Value written to output rows without any entries.
 * @returns `[output, outputShape]`.
 * @throws {Error} If `indices` and `segmentIds` differ in length, the segment
 *     ids are negative, decreasing or out of range, or an index is outside
 *     the input rows.
 */
function sparseSegmentReductionImpl(input, inputShape, inputDType, indices, segmentIds, isMean = false, defaultValue = 0) {
    const numIndices = indices.length;
    if (segmentIds.length !== numIndices) {
        throw new Error(`segmentIds and indices should have same size.`);
    }
    // View the input as a 2-D [numRows, numCol] matrix.
    const numRows = inputShape[0];
    const numCol = input.length / numRows;
    // The ids are assumed sorted, so the last one fixes the output row count.
    const outputRows = numIndices > 0 ? segmentIds[numIndices - 1] + 1 : 0;
    if (outputRows < 0) {
        throw new Error(`segment ids must be >= 0`);
    }
    const outputShape = inputShape.slice();
    outputShape[0] = outputRows;
    let outputLength = 1;
    for (const dim of outputShape) {
        outputLength *= dim;
    }
    const output = tf.util.getArrayFromDType(inputDType, outputLength);
    // Rows that never receive a contribution must be set to the default value
    // explicitly; nothing else writes them.
    if (numIndices === 0) {
        if (outputRows > 0) {
            output.fill(defaultValue);
        }
        return [output, outputShape];
    }
    if (outputRows <= 0) {
        throw new Error(`segment ids must be >= 0`);
    }
    // [segStart, segEnd) is the run of entries that share the id `outRow`.
    let segStart = 0;
    let segEnd = 1;
    // First output row that has not been written yet.
    let firstUnwrittenRow = 0;
    let outRow = segmentIds[segStart];
    for (;;) {
        // Stays 0 when the current run is the last one.
        let nextRow = 0;
        if (segEnd < numIndices) {
            nextRow = segmentIds[segEnd];
            if (nextRow === outRow) {
                // Same segment: extend the run and look at the next entry.
                segEnd++;
                continue;
            }
            // A new segment starts here; its id has to be larger.
            if (nextRow <= outRow) {
                throw new Error(`segment ids are not increasing`);
            }
        }
        if (outRow < 0 || outRow >= outputRows) {
            throw new Error(`Segment id ${outRow} out of range [0, ${outputRows}), possibly because segmentIds input is not sorted.`);
        }
        // Rows skipped between the previous segment and this one get the
        // default value.
        if (outRow > firstUnwrittenRow) {
            output.fill(defaultValue, firstUnwrittenRow * numCol, outRow * numCol);
        }
        for (let i = segStart; i < segEnd; i++) {
            const index = indices[i];
            if (index < 0 || index >= numRows) {
                throw new Error(`Bad: indices[${i}] == ${indices[i]} out of range [0, ${numRows})`);
            }
            for (let j = 0; j < numCol; j++) {
                output[outRow * numCol + j] += input[index * numCol + j];
            }
        }
        if (isMean) {
            for (let j = 0; j < numCol; j++) {
                output[outRow * numCol + j] /= segEnd - segStart;
            }
        }
        // Advance to the run that starts right after this one.
        segStart = segEnd;
        segEnd++;
        firstUnwrittenRow = outRow + 1;
        outRow = nextRow;
        if (segEnd > numIndices) {
            break;
        }
    }
    // Fill the gap at the end with the default value.
    if (firstUnwrittenRow < outputRows) {
        output.fill(defaultValue, firstUnwrittenRow * numCol, outputRows * numCol);
    }
    return [output, outputShape];
}
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Copies a strided selection of `xBuf` into a new buffer of shape `outShape`.
 *
 * The output element at coordinate `loc` is read from the source coordinate
 * `loc[axis] * strides[axis] + begin[axis]` on every axis.
 *
 * @param outShape Shape of the buffer to create.
 * @param xBuf Source buffer; read through `get(...)`, and its `dtype` is
 *     reused for the output.
 * @param strides Step between selected elements, one entry per axis.
 * @param begin Coordinate of the first selected element, one entry per axis.
 * @returns The newly created output buffer.
 */
function stridedSliceImpl(outShape, xBuf, strides, begin) {
    const outBuf = tf.buffer(outShape, xBuf.dtype);
    for (let flat = 0; flat < outBuf.size; flat += 1) {
        const outLoc = outBuf.indexToLoc(flat);
        const sourceLoc = outLoc.map((coord, axis) => coord * strides[axis] + begin[axis]);
        const picked = xBuf.get(...sourceLoc);
        outBuf.set(picked, ...outLoc);
    }
    return outBuf;
}
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2021 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * The StringNGramsOp class creates ngrams from ragged string data.
 * The constructor contains all attributes related to the operation such as
 * padding widths and strings, and the compute function can be used to
 * compute the ngrams for different ragged tensor inputs.
 */
class StringNGramsOp {
    /**
     * @param separator Text placed between adjacent items of an ngram.
     * @param nGramWidths Widths (number of items per ngram) to generate.
     * @param leftPad Text used for every left padding item.
     * @param rightPad Text used for every right padding item.
     * @param padWidth Number of padding items on each side; a negative value
     *     selects `nGramWidth - 1` padding items for each width.
     * @param preserveShortSequences When true, a non-empty row that produces
     *     no ngram for any requested width still yields a single ngram.
     */
    constructor(separator, nGramWidths, leftPad, rightPad, padWidth, preserveShortSequences) {
        // Separator and pads are kept in the form returned by
        // tf.util.encodeString so they can be copied into the output buffers.
        this.separator = tf.util.encodeString(separator);
        this.nGramWidths = nGramWidths;
        this.leftPad = tf.util.encodeString(leftPad);
        this.rightPad = tf.util.encodeString(rightPad);
        this.padWidth = padWidth;
        this.preserveShort = preserveShortSequences;
    }

    /** Returns the number of padding items used on each side for `nGramWidth`. */
    getPadWidth(nGramWidth) {
        // Ngrams can be padded with either a fixed pad width or a dynamic pad
        // width depending on the 'padWidth' arg, but in no case should the padding
        // ever be wider than 'nGramWidth' - 1.
        return Math.min(this.padWidth < 0 ? nGramWidth - 1 : this.padWidth, nGramWidth - 1);
    }

    /** Returns how many ngrams of `nGramWidth` a row of `length` items yields. */
    getNumNGrams(length, nGramWidth) {
        const padWidth = this.getPadWidth(nGramWidth);
        return Math.max(0, ((length + 2 * padWidth) - nGramWidth) + 1);
    }

    /**
     * Writes `numNGrams` ngrams of width `nGramWidth` for the row starting at
     * `data[splitIndex]` into `output`, beginning at `outputStartIndex`.
     * Each ngram is stored as a freshly allocated Uint8Array.
     */
    createNGrams(data, splitIndex, output, outputStartIndex, numNGrams, nGramWidth) {
        // The pad width depends only on nGramWidth, so compute it once rather
        // than once per generated ngram.
        const padWidth = this.getPadWidth(nGramWidth);
        for (let nGramIndex = 0; nGramIndex < numNGrams; ++nGramIndex) {
            const leftPadding = Math.max(0, padWidth - nGramIndex);
            const rightPadding = Math.max(0, padWidth - (numNGrams - (nGramIndex + 1)));
            const numTokens = nGramWidth - (leftPadding + rightPadding);
            const dataStartIndex = splitIndex + (leftPadding > 0 ? 0 : nGramIndex - padWidth);

            // Calculate the total expected size of the nGram so we can reserve the
            // correct amount of space in the buffer.
            let nGramSize = 0;
            // Size of the left padding.
            nGramSize += leftPadding * this.leftPad.length;
            // Size of the tokens.
            for (let n = 0; n < numTokens; ++n) {
                nGramSize += data[dataStartIndex + n].length;
            }
            // Size of the right padding.
            nGramSize += rightPadding * this.rightPad.length;
            // Size of the separators.
            const numSeparators = leftPadding + rightPadding + numTokens - 1;
            nGramSize += numSeparators * this.separator.length;

            // Build the nGram.
            const nGram = new Uint8Array(nGramSize);
            output[outputStartIndex + nGramIndex] = nGram;
            let nextNGramIndex = 0;
            // Bulk-copy a piece at the current write position and advance it.
            const appendToNGram = (str) => {
                nGram.set(str, nextNGramIndex);
                nextNGramIndex += str.length;
            };

            for (let n = 0; n < leftPadding; ++n) {
                appendToNGram(this.leftPad);
                appendToNGram(this.separator);
            }
            // Only output first numTokens - 1 pairs of data and separator.
            for (let n = 0; n < numTokens - 1; ++n) {
                appendToNGram(data[dataStartIndex + n]);
                appendToNGram(this.separator);
            }
            // Handle case when there are no tokens or no right padding as these
            // can result in consecutive separators.
            if (numTokens > 0) {
                // If we have tokens, then output last and then pair each separator
                // with the right padding that follows, to ensure nGram ends either
                // with the token or with the right pad.
                appendToNGram(data[dataStartIndex + numTokens - 1]);
                for (let n = 0; n < rightPadding; ++n) {
                    appendToNGram(this.separator);
                    appendToNGram(this.rightPad);
                }
            }
            else {
                // If we don't have tokens, then the last item inserted into the nGram
                // has been the separator from the left padding loop above. Hence,
                // output right pad and separator and make sure to finish with a
                // padding, not a separator.
                for (let n = 0; n < rightPadding - 1; ++n) {
                    appendToNGram(this.rightPad);
                    appendToNGram(this.separator);
                }
                appendToNGram(this.rightPad);
            }
        }
    }

    /**
     * Computes the ngrams for a ragged input.
     *
     * Data and splits together form the definition of the ragged tensor,
     * where data is 1 dimensional and contains the values of the tensor
     * and splits denotes the indices at which each row starts.
     *
     * @returns `[nGrams, nGramsSplits]`: the generated ngrams and the row
     *     splits of the output.
     * @throws Error when the splits are not valid, ordered indices into data.
     */
    compute(data, splits) {
        // Validate that the splits are valid indices into data, only if there are
        // splits specified.
        const inputDataSize = data.length;
        const splitsSize = splits.length;
        if (splitsSize > 0) {
            let prevSplit = splits[0];
            if (prevSplit !== 0) {
                throw new Error(`First split value must be 0, got ${prevSplit}`);
            }
            for (let i = 1; i < splitsSize; ++i) {
                const validSplits = splits[i] >= prevSplit && splits[i] <= inputDataSize;
                if (!validSplits) {
                    throw new Error(`Invalid split value ${splits[i]}, must be in [${prevSplit}, ${inputDataSize}]`);
                }
                prevSplit = splits[i];
            }
            if (prevSplit !== inputDataSize) {
                throw new Error(`Last split value must be data size. Expected ${inputDataSize}, got ${prevSplit}`);
            }
        }

        const numBatchItems = splitsSize - 1;
        const nGramsSplits = tf.util.getArrayFromDType('int32', splitsSize);
        // If there is no data or size, return an empty ragged tensor.
        if (inputDataSize === 0 || splitsSize === 0) {
            const empty = new Array(inputDataSize);
            for (let i = 0; i <= numBatchItems; ++i) {
                nGramsSplits[i] = 0;
            }
            return [empty, nGramsSplits];
        }

        // First pass: count the ngrams per row to build the output splits.
        nGramsSplits[0] = 0;
        for (let i = 1; i <= numBatchItems; ++i) {
            const length = splits[i] - splits[i - 1];
            let numNGrams = 0;
            this.nGramWidths.forEach((nGramWidth) => {
                numNGrams += this.getNumNGrams(length, nGramWidth);
            });
            if (this.preserveShort && length > 0 && numNGrams === 0) {
                numNGrams = 1;
            }
            nGramsSplits[i] = nGramsSplits[i - 1] + numNGrams;
        }

        // Second pass: materialize the ngrams.
        const nGrams = new Array(nGramsSplits[numBatchItems]);
        for (let i = 0; i < numBatchItems; ++i) {
            const splitIndex = splits[i];
            let outputStartIdx = nGramsSplits[i];
            this.nGramWidths.forEach((nGramWidth) => {
                const length = splits[i + 1] - splits[i];
                const numNGrams = this.getNumNGrams(length, nGramWidth);
                this.createNGrams(data, splitIndex, nGrams, outputStartIdx, numNGrams, nGramWidth);
                outputStartIdx += numNGrams;
            });
            // If we're preserving short sequences, check to see if no sequence was
            // generated by comparing the current output start idx to the original
            // one (nGramsSplits[i]). If no ngrams were generated, then they will
            // be equal (since we increment outputStartIdx by numNGrams every
            // time we create a set of ngrams.)
            if (this.preserveShort && outputStartIdx === nGramsSplits[i]) {
                const dataLength = splits[i + 1] - splits[i];
                // One legitimate reason to not have any ngrams when this.preserveShort
                // is true is if the sequence itself is empty. In that case, move on.
                if (dataLength === 0) {
                    continue;
                }
                // We don't have to worry about dynamic padding sizes here: if padding
                // was dynamic, every sequence would have had sufficient padding to
                // generate at least one nGram.
                const nGramWidth = dataLength + 2 * this.padWidth;
                const numNGrams = 1;
                this.createNGrams(data, splitIndex, nGrams, outputStartIdx, numNGrams, nGramWidth);
            }
        }
        return [nGrams, nGramsSplits];
    }
}
|
|
/**
 * Convenience entry point: configures a StringNGramsOp with the given
 * attributes and runs it on one ragged input.
 *
 * @param data Flat values of the ragged tensor.
 * @param dataSplits Row splits of the ragged tensor.
 * @returns Whatever `StringNGramsOp.compute` returns for this input.
 */
function stringNGramsImpl(data, dataSplits, separator, nGramWidths, leftPad, rightPad, padWidth, preserveShortSequences) {
    const op = new StringNGramsOp(separator, nGramWidths, leftPad, rightPad, padWidth, preserveShortSequences);
    return op.compute(data, dataSplits);
}
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2021 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Splits a byte sequence into tokens.
 *
 * @param str Typed array to split; tokens are `subarray` views onto it.
 * @param delimiters Typed array of delimiter values. Empty means "split into
 *     single elements"; otherwise the input is cut at every delimiter value.
 * @param skipEmpty When true, zero-length tokens are dropped (not applied in
 *     the empty-delimiter case, where every token has length one).
 * @returns Array of token views, in input order.
 */
function split(str, delimiters, skipEmpty) {
    if (!str.length) {
        return [];
    }
    const tokens = [];

    // No delimiter: every element becomes its own token.
    if (delimiters.length === 0) {
        for (let pos = 0; pos < str.length; ++pos) {
            tokens.push(str.subarray(pos, pos + 1));
        }
        return tokens;
    }

    // Records a token unless it is empty and empty tokens are being skipped.
    const keep = (token) => {
        if (!skipEmpty || token.length !== 0) {
            tokens.push(token);
        }
    };

    // Exactly one delimiter: jump from hit to hit with indexOf.
    if (delimiters.length === 1) {
        const delimiter = delimiters[0];
        let start = 0;
        let hit = str.indexOf(delimiter, start);
        while (hit !== -1) {
            keep(str.subarray(start, hit));
            start = hit + 1;
            hit = str.indexOf(delimiter, start);
        }
        // Whatever follows the last delimiter is the final token.
        keep(str.subarray(start));
        return tokens;
    }

    // Several delimiters: scan once and cut wherever any of them appears.
    // The extra iteration at i === str.length flushes the trailing token.
    let tokenStart = 0;
    for (let i = 0; i <= str.length; i++) {
        const atBoundary = i === str.length || delimiters.includes(str[i]);
        if (!atBoundary) {
            continue;
        }
        keep(str.subarray(tokenStart, i));
        tokenStart = i + 1;
    }
    return tokens;
}
|
|
/**
 * Splits every element of `input` with `split` and lays the tokens out in
 * sparse form.
 *
 * @param input Array of byte sequences to split.
 * @param delimiter Delimiter bytes forwarded to `split`; an empty delimiter
 *     means the input is split element by element.
 * @param skipEmpty Forwarded to `split`; drops empty tokens when true.
 * @returns `[indices, values, shape]` where `indices` is a flattened
 *     `[outputSize, 2]` list of (row, column) pairs, `values` holds the
 *     tokens in the same order and `shape` is
 *     `[batchSize, largest token count of any row]`.
 */
function stringSplitImpl(input, delimiter, skipEmpty) {
    const batchSize = input.length;
    const values = [];
    let maxNumEntries = 0;
    const numIndices = new Array(batchSize);
    for (let i = 0; i < batchSize; ++i) {
        const parts = split(input[i], delimiter, skipEmpty);
        const nEntries = parts.length;
        numIndices[i] = nEntries;
        maxNumEntries = Math.max(maxNumEntries, nEntries);
        // Append one token at a time: spreading `parts` into push() passes every
        // token as a call argument and overflows the stack for very long rows.
        for (let j = 0; j < nEntries; ++j) {
            values.push(parts[j]);
        }
    }
    const outputSize = values.length;
    const indices = tf.util.getArrayFromDType('int32', outputSize * 2);
    const shape = [batchSize, maxNumEntries];
    let c = 0;
    for (let i = 0; i < batchSize; ++i) {
        for (let j = 0; j < numIndices[i]; ++j) {
            // indices is a 2d tensor with shape of [outputSize, 2]
            indices[c * 2] = i;
            indices[c * 2 + 1] = j;
            ++c;
        }
    }
    return [indices, values, shape];
}
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2021 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Maps every element of `input` to a bucket id.
 *
 * For each element the value returned by `tf.util.fingerPrint64` is reduced
 * with `.modulo(numBuckets)` and its `getLowBitsUnsigned()` result is stored.
 *
 * @param input Array-like of elements to hash.
 * @param numBuckets Modulus applied to each fingerprint.
 * @returns An 'int32' array (from `tf.util.getArrayFromDType`) with one
 *     bucket id per input element.
 */
function stringToHashBucketFastImpl(input, numBuckets) {
    const count = input.length;
    const buckets = tf.util.getArrayFromDType('int32', count);
    for (let idx = 0; idx < count; ++idx) {
        const fingerprint = tf.util.fingerPrint64(input[idx]);
        buckets[idx] = fingerprint.modulo(numBuckets).getLowBitsUnsigned();
    }
    return buckets;
}
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// Element-wise subtraction kernel, built from the shared binary-kernel factory.
const subImpl = createSimpleBinaryKernelImpl((minuend, subtrahend) => minuend - subtrahend);
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * An implementation of the tile kernel shared between webgl and cpu for string
 * tensors only.
 *
 * @param xBuf Source buffer; its `rank`, `shape`, `dtype`, `values` and
 *     `locToIndex` are read.
 * @param reps Repetition count per dimension.
 * @returns A buffer created with `tf.buffer` whose shape is the source shape
 *     multiplied by `reps` per dimension, filled by wrapping every output
 *     coordinate back into the source.
 */
function tileImpl(xBuf, reps) {
    const tiledShape = new Array(xBuf.rank);
    for (let axis = 0; axis < tiledShape.length; axis++) {
        tiledShape[axis] = xBuf.shape[axis] * reps[axis];
    }
    const tiled = tf.buffer(tiledShape, xBuf.dtype);
    for (let flat = 0; flat < tiled.values.length; ++flat) {
        const tiledLoc = tiled.indexToLoc(flat);
        // Wrap each output coordinate back into the source extent.
        const sourceLoc = Array.from({ length: xBuf.rank }, (_, axis) => tiledLoc[axis] % xBuf.shape[axis]);
        tiled.values[flat] = xBuf.values[xBuf.locToIndex(sourceLoc)];
    }
    return tiled;
}
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Finds the `k` largest values (and their positions) along the last
 * dimension of a tensor.
 *
 * The input is treated as a 2d tensor `[batch, lastDim]` and the top-k is
 * computed independently for every row. Values are returned in descending
 * order; equal values are ordered by ascending index within the row.
 *
 * @param x Flat typed array holding the tensor values.
 * @param xShape Shape of the tensor; only the last dimension is used to
 *     partition `x` into rows.
 * @param xDtype Dtype used for the values output.
 * @param k Number of entries to keep per row.
 * @param sorted Accepted for interface compatibility; not read by this
 *     implementation, which always returns sorted results.
 * @returns `[values, indices]`, two buffers created with `tf.buffer` whose
 *     shape is `xShape` with the last dimension replaced by `k`.
 */
function topKImpl(x, xShape, xDtype, k, sorted) {
    // Reshape into a 2d tensor [batch, lastDim] and compute topk along lastDim.
    const lastDim = xShape[xShape.length - 1];
    const batch = x.length / lastDim;
    const size = lastDim;
    const allTopKVals = tf.util.getTypedArrayFromDType(xDtype, batch * k);
    const allTopKIndices = tf.util.getTypedArrayFromDType('int32', batch * k);

    // Descending by value. Ties are broken explicitly by ascending index so the
    // result does not depend on the engine's sort being stable.
    const compareDescending = (left, right) => left.value !== right.value ?
        right.value - left.value :
        left.index - right.index;

    for (let row = 0; row < batch; row++) {
        const offset = row * size;
        const vals = x.subarray(offset, offset + size);
        const valAndInd = [];
        for (let i = 0; i < vals.length; i++) {
            valAndInd.push({ value: vals[i], index: i });
        }
        valAndInd.sort(compareDescending);

        const outOffset = row * k;
        for (let i = 0; i < k; i++) {
            allTopKVals[outOffset + i] = valAndInd[i].value;
            allTopKIndices[outOffset + i] = valAndInd[i].index;
        }
    }

    // Reshape back to the original input shape, except that the last
    // dimension is k.
    const outputShape = xShape.slice();
    outputShape[outputShape.length - 1] = k;
    return [
        tf.buffer(outputShape, xDtype, allTopKVals),
        tf.buffer(outputShape, 'int32', allTopKIndices)
    ];
}
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Finds the unique slices of a tensor along one axis.
 *
 * @param values Flat values of the input tensor.
 * @param axis Axis along which slices are compared; normalized with
 *     `tf.util.parseAxisParam`.
 * @param shape Shape of the input tensor.
 * @param dtype Dtype of the input tensor (used for the buffers).
 * @returns `{outputValues, outputShape, indices}`: the values of the unique
 *     slices in order of first appearance, the input shape with the axis size
 *     replaced by the number of unique slices, and for every position along
 *     the axis the index of its slice within the unique slices.
 */
function uniqueImpl(values, axis, shape, dtype) {
    // Normalize and validate axis.
    const $axis = tf.util.parseAxisParam(axis, shape)[0];

    // Calculate the new shape that is suitable for extracting data along the
    // given axis.
    //
    // The rank is 3.
    // The size of the 1st dimension is the size of all the axes < the given axis.
    // The size of the 2nd dimension is the same as the size of the given axis.
    // The size of the 3rd dimension is the size of all the axes > the given axis.
    //
    // For example, for a 4D tensor with shape=[2, 3, 5, 4] and axis=2, the
    // newShape would be: [2*3, 5, 4].
    //
    // Note that this is not the final output shape. This will be the shape for an
    // intermediate TensorBuffer (see inputBuffer below) to allow us to extract
    // values along the given axis. To demonstrate how it works, consider the
    // following example:
    //
    // Input: a 3D tensor, with shape [1, 2, 3]
    // [
    //   [
    //     [1,2,3],
    //     [4,5,6]
    //   ]
    // ]
    // Axis: 2 (the last axis).
    // Along axis 2, we expect to extract 3 tensors: [1,4], [2,5], [3,6].
    //
    // For this example, newShape would be: [2, 3, 1], where 2 is calculated from
    // 1*2. The re-shaped data would look like:
    //
    // [
    //   [
    //     [1], [2], [3]
    //   ],
    //   [
    //     [4], [5], [6]
    //   ]
    // ]
    //
    // Then, we can construct a 3-level nested loop by the following dimension
    // order to extract the values along the axis (dimension1):
    // i: dimension1       // 0,1,2 (newShape[1])
    //   m: dimension0     // 0,1   (newShape[0])
    //     n: dimension2   // 0     (newShape[2])
    //
    //                       m, i, n
    //                      ---------
    // Iteration 0: data at [0, 0, 0] => "1"
    // Iteration 1: data at [1, 0, 0] => "4"
    // We got [1,4].
    // Iteration 2: data at [0, 1, 0] => "2"
    // Iteration 3: data at [1, 1, 0] => "5"
    // We got [2,5].
    // Iteration 4: data at [0, 2, 0] => "3"
    // Iteration 5: data at [1, 2, 0] => "6"
    // We got [3,6].
    const newShape = [1, shape[0], 1];
    for (let i = 0; i < $axis; i++) {
        newShape[0] *= shape[i];
    }
    newShape[1] = shape[$axis];
    for (let i = $axis + 1; i < shape.length; i++) {
        newShape[2] *= shape[i];
    }

    // A map from unique elements (their string representations) to their values
    // in "indices" (below). A Map is used instead of a plain object so that
    // keys such as "constructor" are not confused with inherited properties,
    // and so the number of unique elements is available in constant time.
    const uniqueElements = new Map();
    // The indices of each unique element in the original tensor along the given
    // axis. It is 1D and has the same size as the given axis.
    const indices = new Int32Array(shape[$axis]);
    // Create a buffer so we can easily extract value at a given location.
    const inputBuffer = new tf.TensorBuffer(newShape, dtype, values);
    // The indices along the given axis that have unique elements. This is a
    // de-duped version of "indices" above.
    const uniqueIndices = [];
    const is1DTensor = newShape[0] === 1 && newShape[2] === 1;
    for (let i = 0; i < shape[$axis]; i++) {
        // Extract values along the axis.
        let element;
        if (is1DTensor) {
            // Fast path for 1D tensor input.
            element = values[i].toString();
        }
        else {
            const axisValues = [];
            for (let m = 0; m < newShape[0]; m++) {
                for (let n = 0; n < newShape[2]; n++) {
                    // Prefix every part with its length so that slices whose parts
                    // contain the separator (e.g. "a,b" + "c" versus "a" + "b,c")
                    // cannot produce the same key.
                    const part = String(inputBuffer.get(m, i, n));
                    axisValues.push(`${part.length}:${part}`);
                }
            }
            element = axisValues.join(',');
        }
        // Dedup and update various indices.
        const existingIndex = uniqueElements.get(element);
        if (existingIndex !== undefined) {
            indices[i] = existingIndex;
        }
        else {
            const uniqueIndex = uniqueElements.size;
            uniqueElements.set(element, uniqueIndex);
            indices[i] = uniqueIndex;
            uniqueIndices.push(i);
        }
    }

    // Now we know where each of the unique elements are located along the axis
    // (uniqueIndices). Extract them from input buffer and store them in the
    // output buffer.
    const outputTmpShape = newShape.slice();
    outputTmpShape[1] = uniqueElements.size;
    const outputBuffer = new tf.TensorBuffer(outputTmpShape, dtype);
    uniqueIndices.forEach((uniqueElementIndex, i) => {
        for (let m = 0; m < newShape[0]; m++) {
            for (let n = 0; n < newShape[2]; n++) {
                outputBuffer.set(inputBuffer.get(m, uniqueElementIndex, n), m, i, n);
            }
        }
    });

    // The output shape can be calculated from the input shape with the size of
    // the given axis replaced by the number of unique elements along that axis.
    const outputShape = shape.slice();
    outputShape[$axis] = outputTmpShape[1];
    return {
        outputValues: outputBuffer.values,
        outputShape,
        indices,
    };
}
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// CPU implementations of the shared kernels, re-bound under a "CPU" suffix.
var addImplCPU = addImpl;
var bincountImplCPU = bincountImpl;
var bincountReduceImplCPU = bincountReduceImpl;
var ceilImplCPU = ceilImpl;
var concatImplCPU = concatImpl;
var equalImplCPU = equalImpl;
var expImplCPU = expImpl;
var expm1ImplCPU = expm1Impl;
var floorImplCPU = floorImpl;
var gatherNdImplCPU = gatherNdImpl;
var gatherV2ImplCPU = gatherV2Impl;
var greaterImplCPU = greaterImpl;
var greaterEqualImplCPU = greaterEqualImpl;
var lessImplCPU = lessImpl;
var lessEqualImplCPU = lessEqualImpl;
var linSpaceImplCPU = linSpaceImpl;
var logImplCPU = logImpl;
var maxImplCPU = maxImpl;
var maximumImplCPU = maximumImpl;
var minimumImplCPU = minimumImpl;
var multiplyImplCPU = multiplyImpl;
var negImplCPU = negImpl;
var notEqualImplCPU = notEqualImpl;
var prodImplCPU = prodImpl;
var rangeImplCPU = rangeImpl;
var rsqrtImplCPU = rsqrtImpl;
var simpleAbsImplCPU = simpleAbsImpl;
var sliceImplCPU = sliceImpl;
var sparseFillEmptyRowsImplCPU = sparseFillEmptyRowsImpl;
var sparseReshapeImplCPU = sparseReshapeImpl;
var sparseSegmentReductionImplCPU = sparseSegmentReductionImpl;
var stridedSliceImplCPU = stridedSliceImpl;
var stringNGramsImplCPU = stringNGramsImpl;
var stringSplitImplCPU = stringSplitImpl;
var stringToHashBucketFastImplCPU = stringToHashBucketFastImpl;
var subImplCPU = subImpl;
var tileImplCPU = tileImpl;
var topKImplCPU = topKImpl;
var transposeImplCPU = transposeImpl;
var uniqueImplCPU = uniqueImpl;
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Returns the first `rank` component accessors of `name`, e.g.
 * `['rc.x', 'rc.y']` for `('rc', 2)`. At most six components
 * (x, y, z, w, u, v) are available.
 */
function getVecChannels(name, rank) {
    const components = ['x', 'y', 'z', 'w', 'u', 'v'];
    return components.slice(0, rank).map((component) => `${name}.${component}`);
}
|
|
/**
 * Returns the channel accessors for a coordinate variable: the bare name for
 * rank 1, otherwise the per-component accessors from `getVecChannels`.
 */
function getChannels(name, rank) {
    return rank === 1 ? [name] : getVecChannels(name, rank);
}
|
|
/**
 * Builds the comma-separated coordinate list for a source lookup: `'rc'` for
 * rank 1, otherwise the first `rank` entries of `dims` joined with ','.
 */
function getSourceCoords(rank, dims) {
    if (rank === 1) {
        return 'rc';
    }
    const parts = [];
    for (let axis = 0; axis < rank; axis++) {
        // Stringify explicitly so missing entries appear as "undefined",
        // exactly as string concatenation would render them.
        parts.push(`${dims[axis]}`);
    }
    return parts.join(',');
}
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
var PackProgram = /** @class */ (function () {
    /**
     * Shader program description that reads the unpacked input 'A' and writes
     * a packed output of shape `outputShape`.
     *
     * For rank 0 the single value is placed in the first channel. For higher
     * ranks the shader source is assembled from the coordinate type,
     * out-of-bounds condition, setup and output snippets produced by the
     * helper functions.
     */
    function PackProgram(outputShape) {
        this.variableNames = ['A'];
        this.packedInputs = false;
        this.packedOutput = true;
        this.outputShape = outputShape;
        const rank = outputShape.length;
        if (rank === 0) {
            this.userCode = `\n void main() {\n setOutput(vec4(getA(), 0., 0., 0.));\n }\n `;
            return;
        }
        const channels = getChannels('rc', rank);
        const coordsType = getCoordsDataType(rank);
        const outOfBounds = getOutOfBoundsCondition(rank, outputShape, channels);
        // The two innermost dimensions are the columns and rows of the output.
        const cols = outputShape[rank - 1];
        const rows = outputShape[rank - 2];
        const setup = getSetup(rank, cols, rows, channels);
        const output = getOutput(outputShape, channels);
        this.userCode = `\n void main() {\n ${coordsType} rc = getOutputCoords();\n\n if(${outOfBounds}) {\n setOutput(vec4(0));\n } else {\n ${setup}\n\n setOutput(vec4(${output}));\n }\n }\n `;
    }
    return PackProgram;
}());
|
|
/**
 * Returns the four source coordinate strings covered by one packed output
 * texel, in the order (r, c), (r, cp1), (rp1, c), (rp1, cp1). For ranks above
 * two, entries of `dims` are prepended as leading coordinates.
 */
function getSourceCoordsArr(rank, dims) {
    // The leading (outer) coordinates are the same for all four entries.
    let outerPrefix = '';
    for (let d = 2; d < rank; d++) {
        outerPrefix = `${dims[dims.length - 1 - d]},${outerPrefix}`;
    }
    const coords = [];
    for (const rowName of ['r', 'rp1']) {
        for (const colName of ['c', 'cp1']) {
            coords.push(`${outerPrefix}${rowName}, ${colName}`);
        }
    }
    return coords;
}
|
|
/**
 * Builds the shader condition that detects output coordinates outside
 * `shape`. Rank 1 compares `rc` against the only dimension; higher ranks
 * check the two innermost dimensions, joined with '||'.
 */
function getOutOfBoundsCondition(rank, shape, dims) {
    if (rank === 1) {
        return `rc > ${shape[0]}`;
    }
    const checks = [];
    for (let i = rank - 2; i < rank; i++) {
        checks.push(`${dims[i]} >= ${shape[i]}`);
    }
    return checks.join('||');
}
|
|
/**
 * Builds the shader preamble declaring the row/column coordinates (`r`, `c`),
 * their successors (`rp1`, `cp1`) and the edge flags (`rEdge`, `cEdge`).
 * Rank 1 needs no preamble and yields an empty string.
 */
function getSetup(rank, cols, rows, dims) {
    if (rank === 1) {
        return '';
    }
    // The last two entries of dims address the row and the column.
    const [rowDim, colDim] = dims.slice(-2);
    return `\n int r = ${rowDim};\n int c = ${colDim};\n int rp1 = r + 1;\n int cp1 = c + 1;\n\n bool cEdge = cp1 >= ${cols};\n bool rEdge = rp1 >= ${rows};\n `;
}
|
|
/**
 * Builds the four comma-separated channel expressions passed to `vec4(...)`
 * in the pack shader. Reads that would fall past an edge are replaced by `0.`.
 */
function getOutput(shape, dims) {
    const rank = shape.length;
    if (rank === 1) {
        return `getA(rc),\n rc + 1 >= ${shape[0]} ? 0. : getA(rc + 1),\n 0, 0`;
    }
    const [topLeft, topRight, bottomLeft, bottomRight] = getSourceCoordsArr(rank, dims);
    return `getA(${topLeft}),\n cEdge ? 0. : getA(${topRight}),\n rEdge ? 0. : getA(${bottomLeft}),\n rEdge || cEdge ? 0. : getA(${bottomRight})`;
}
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
var ReshapePackedProgram = /** @class */ (function () {
    /**
     * Shader program description that reshapes the packed input 'A' into a
     * packed output of shape `outputShape`.
     *
     * The generated shader fills the four channels of an output texel one by
     * one: each channel's output coordinate is flattened, mapped back to input
     * coordinates, and the matching input channel is fetched. Channels 1-3 are
     * guarded by a bounds check against the output rows and columns.
     */
    function ReshapePackedProgram(outputShape, inputShape) {
        this.variableNames = ['A'];
        this.packedInputs = true;
        this.packedOutput = true;
        this.outputShape = outputShape;

        const channelSnippets = [];
        for (let channel = 0; channel < 4; channel++) {
            // Odd channels step one column to the right, channels 2 and 3 one
            // row down.
            let thisRC = 'thisRC = rc;';
            if (channel % 2 === 1) {
                thisRC += 'thisRC.z += 1;';
            }
            if (channel > 1) {
                thisRC += 'thisRC.y += 1;';
            }
            const guardOpen = channel > 0 ? 'if(thisRC.y < rows && thisRC.z < cols){' : '';
            const guardClose = channel > 0 ? '}' : '';
            channelSnippets.push(`\n ${thisRC}\n ${guardOpen}\n int flatIndex = getFlatIndex(thisRC);\n\n ivec3 inputRC = inputCoordsFromReshapedOutCoords(flatIndex);\n vec2 inputRCInnerDims = vec2(float(inputRC.y),float(inputRC.z));\n\n result[${channel}] =\n getChannel(getA(inputRC.x, inputRC.y, inputRC.z), inputRCInnerDims);\n ${guardClose}\n `);
        }
        const mainLoop = channelSnippets.join('');

        const inputCoordsFn = getReshapedInputCoords(inputShape);
        const flatIndexFn = getFlatIndexFrom3D(outputShape);
        this.userCode = `\n ${inputCoordsFn}\n ${flatIndexFn}\n\n void main() {\n ivec3 rc = getOutputCoords();\n\n vec4 result = vec4(0.);\n\n ivec3 thisRC;\n int rows = ${outputShape[1]};\n int cols = ${outputShape[2]};\n\n ${mainLoop}\n\n setOutput(result);\n }\n `;
    }
    return ReshapePackedProgram;
}());
|
|
/**
 * Generates the shader function `inputCoordsFromReshapedOutCoords`, which
 * converts a flat index into (r, c, d) coordinates of a tensor with the
 * given `shape`, using the snippet from `getLogicalCoordinatesFromFlatIndex`.
 */
function getReshapedInputCoords(shape) {
    const indexToCoords = getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd'], shape);
    return `\n ivec3 inputCoordsFromReshapedOutCoords(int index) {\n ${indexToCoords}\n return ivec3(r, c, d);\n }\n `;
}
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2017 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
var TextureManager = /** @class */ (function () {
|
|
function TextureManager(gpgpu) {
|
|
this.gpgpu = gpgpu;
|
|
this.numUsedTextures = 0;
|
|
this.numFreeTextures = 0;
|
|
this._numBytesAllocated = 0;
|
|
this._numBytesFree = 0; // How many bytes that have been allocated
|
|
// are available for reuse.
|
|
this.freeTextures = {};
|
|
this.logEnabled = false;
|
|
this.usedTextures = {};
|
|
}
|
|
TextureManager.prototype.acquireTexture = function (shapeRC, usage, isPacked) {
|
|
var physicalTexType = getPhysicalFromLogicalTextureType(usage, isPacked);
|
|
var shapeKey = getKeyFromTextureShape(shapeRC, physicalTexType, isPacked);
|
|
if (!(shapeKey in this.freeTextures)) {
|
|
this.freeTextures[shapeKey] = [];
|
|
}
|
|
if (!(shapeKey in this.usedTextures)) {
|
|
this.usedTextures[shapeKey] = [];
|
|
}
|
|
var texBytes = computeBytes(shapeRC, physicalTexType, this.gpgpu.gl, this.gpgpu.textureConfig, isPacked);
|
|
if (this.freeTextures[shapeKey].length > 0) {
|
|
this.numFreeTextures--;
|
|
this.numUsedTextures++;
|
|
this._numBytesFree -= texBytes;
|
|
this.log();
|
|
var newTexture_1 = this.freeTextures[shapeKey].shift();
|
|
this.usedTextures[shapeKey].push(newTexture_1);
|
|
return newTexture_1;
|
|
}
|
|
var newTexture;
|
|
if (physicalTexType === PhysicalTextureType.PACKED_2X2_FLOAT32) {
|
|
newTexture = this.gpgpu.createPackedMatrixTexture(shapeRC[0], shapeRC[1]);
|
|
}
|
|
else if (physicalTexType === PhysicalTextureType.PACKED_2X2_FLOAT16) {
|
|
newTexture =
|
|
this.gpgpu.createFloat16PackedMatrixTexture(shapeRC[0], shapeRC[1]);
|
|
}
|
|
else if (physicalTexType === PhysicalTextureType.UNPACKED_FLOAT32) {
|
|
newTexture =
|
|
this.gpgpu.createFloat32MatrixTexture(shapeRC[0], shapeRC[1]);
|
|
}
|
|
else if (physicalTexType === PhysicalTextureType.UNPACKED_FLOAT16) {
|
|
newTexture =
|
|
this.gpgpu.createFloat16MatrixTexture(shapeRC[0], shapeRC[1]);
|
|
}
|
|
else if (physicalTexType === PhysicalTextureType.PACKED_4X1_UNSIGNED_BYTE) {
|
|
newTexture =
|
|
this.gpgpu.createUnsignedBytesMatrixTexture(shapeRC[0], shapeRC[1]);
|
|
}
|
|
this.usedTextures[shapeKey].push(newTexture);
|
|
this.numUsedTextures++;
|
|
this._numBytesAllocated += texBytes;
|
|
this.log();
|
|
return newTexture;
|
|
};
|
|
/**
 * Returns a texture to the manager once its owner no longer needs it.
 *
 * The texture is either deleted (when the `WEBGL_DELETE_TEXTURE_THRESHOLD`
 * flag is not -1 and the allocated byte count exceeds it) or parked in the
 * free pool for its shape key so a later `acquireTexture` can recycle it.
 *
 * @param texture The texture being released.
 * @param shape `[rows, columns]` the texture was acquired with.
 * @param logicalTexType Logical usage the texture was acquired with.
 * @param isPacked Whether the texture holds packed data.
 * @throws {Error} If the texture is not in this manager's used pool for the
 *     computed shape key. No bookkeeping is modified in that case.
 */
TextureManager.prototype.releaseTexture = function (texture, shape, logicalTexType, isPacked) {
    if (this.freeTextures == null) {
        // Already disposed.
        return;
    }
    var physicalTexType = getPhysicalFromLogicalTextureType(logicalTexType, isPacked);
    var shapeKey = getKeyFromTextureShape(shape, physicalTexType, isPacked);
    // Validate ownership before touching any counters or pools, so a bad
    // release neither corrupts the bookkeeping nor deletes/recycles a texture
    // this manager never handed out. A missing used list (shape key never
    // acquired) is treated the same as "texture not found".
    var texList = this.usedTextures[shapeKey];
    var texIndex = texList == null ? -1 : texList.indexOf(texture);
    if (texIndex < 0) {
        throw new Error('Cannot release a texture that was never provided by this ' +
            'texture manager');
    }
    if (!(shapeKey in this.freeTextures)) {
        this.freeTextures[shapeKey] = [];
    }
    var texBytes = computeBytes(shape, physicalTexType, this.gpgpu.gl, this.gpgpu.textureConfig, isPacked);
    var deleteTexThreshold = tf.env().get('WEBGL_DELETE_TEXTURE_THRESHOLD');
    if (deleteTexThreshold !== -1 &&
        this._numBytesAllocated > deleteTexThreshold) {
        // Over the threshold: free the GPU memory instead of pooling it.
        this.gpgpu.deleteMatrixTexture(texture);
        this._numBytesAllocated -= texBytes;
    }
    else {
        this.freeTextures[shapeKey].push(texture);
        this.numFreeTextures++;
        this._numBytesFree += texBytes;
    }
    this.numUsedTextures--;
    texList.splice(texIndex, 1);
    this.log();
};
|
|
/**
 * Prints texture pool statistics to the console when `logEnabled` is set:
 * free/used texture counts, bytes allocated, and bytes sitting unused in the
 * free pool (absolute and as a percentage of the allocated bytes).
 */
TextureManager.prototype.log = function () {
    if (!this.logEnabled) {
        return;
    }
    var total = this.numFreeTextures + this.numUsedTextures;
    console.log('Free/Used', this.numFreeTextures + " / " + this.numUsedTextures, "(" + total + ")");
    // Guard against 0 / 0 when nothing is allocated, which would otherwise
    // be reported as "NaN%".
    var freeRatio = this._numBytesAllocated > 0 ?
        this._numBytesFree / this._numBytesAllocated :
        0;
    console.log("Bytes allocated: " + this._numBytesAllocated);
    console.log("Bytes unused: " + this._numBytesFree + " (" + Math.round(100 * freeRatio) + "%)");
};
|
|
// Read-only `numBytesAllocated` accessor exposing the private
// `_numBytesAllocated` counter.
Object.defineProperty(TextureManager.prototype, "numBytesAllocated", {
    configurable: true,
    enumerable: true,
    get: function () {
        var allocated = this._numBytesAllocated;
        return allocated;
    }
});
|
|
// Read-only `numBytesFree` accessor exposing the private `_numBytesFree`
// counter.
Object.defineProperty(TextureManager.prototype, "numBytesFree", {
    configurable: true,
    enumerable: true,
    get: function () {
        var unused = this._numBytesFree;
        return unused;
    }
});
|
|
/** @returns The current value of the `numUsedTextures` counter. */
TextureManager.prototype.getNumUsedTextures = function () {
    var inUse = this.numUsedTextures;
    return inUse;
};
|
|
/** @returns The current value of the `numFreeTextures` counter. */
TextureManager.prototype.getNumFreeTextures = function () {
    var pooled = this.numFreeTextures;
    return pooled;
};
|
|
/**
 * Deletes every texture tracked by the manager (free and used pools alike)
 * and resets all counters. Calling it again after disposal is a no-op.
 */
TextureManager.prototype.dispose = function () {
    if (this.freeTextures == null) {
        // Already disposed.
        return;
    }
    var manager = this;
    // Deletes all textures stored in a `{shapeKey: texture[]}` pool map.
    var deletePool = function (pools) {
        for (var shapeKey in pools) {
            pools[shapeKey].forEach(function (tex) {
                manager.gpgpu.deleteMatrixTexture(tex);
            });
        }
    };
    deletePool(this.freeTextures);
    deletePool(this.usedTextures);
    // Null pools mark the manager as disposed for later calls.
    this.freeTextures = null;
    this.usedTextures = null;
    this.numUsedTextures = 0;
    this.numFreeTextures = 0;
    this._numBytesAllocated = 0;
    this._numBytesFree = 0;
};
|
|
return TextureManager;
|
|
}());
|
|
/**
 * Maps a WebGL internal texture format to its size in bytes per texel.
 *
 * @param gl WebGL context supplying the format enum values.
 * @param internalFormat The internal format constant to look up.
 * @returns Bytes per texel for the format.
 * @throws {Error} If the format is not one of the handled constants.
 */
function numBytesForInternalFormat(gl, internalFormat) {
    // tslint:disable-next-line:no-any
    var glAny = gl;
    // Cases are compared in order with strict equality, exactly like an
    // if / else-if chain.
    switch (internalFormat) {
        case glAny.R32F:
            return 4;
        case glAny.R16F:
            return 2;
        case glAny.RGBA32F:
            return 16;
        case gl.RGBA:
            return 16;
        case glAny.RGBA16F:
            return 8;
        default:
            throw new Error("Unknown internal format " + internalFormat);
    }
}
|
|
/**
 * Computes the byte size of a texture for a logical `[rows, columns]` shape.
 *
 * Packed status cannot be inferred from the texture type because, depending
 * on the textureConfig, different texture types may resolve to the same
 * internal format (e.g. in WebGL1, the internal format for UNPACKED_FLOAT16
 * textures is gl.RGBA), so `isPacked` is passed in explicitly.
 *
 * @returns Number of texels times bytes per texel.
 */
function computeBytes(shape, physicalTexType, gl, textureConfig, isPacked) {
    var internalFormat = internalFormatForPhysicalTexType(physicalTexType, textureConfig);
    var texDims = isPacked ?
        getPackedMatrixTextureShapeWidthHeight(shape[0], shape[1]) :
        getUnpackedMatrixTextureShapeWidthHeight(shape[0], shape[1]);
    var texelCount = texDims[0] * texDims[1];
    return texelCount * numBytesForInternalFormat(gl, internalFormat);
}
|
|
/**
 * Resolves the internal format used for a physical texture type under the
 * given texture configuration by delegating to the matching
 * `getInternalFormatFor*` helper.
 *
 * @throws {Error} If `physicalTexType` is not a known physical texture type.
 */
function internalFormatForPhysicalTexType(physicalTexType, textureConfig) {
    if (physicalTexType === PhysicalTextureType.PACKED_2X2_FLOAT32) {
        return getInternalFormatForPackedMatrixTexture(textureConfig);
    }
    if (physicalTexType === PhysicalTextureType.PACKED_2X2_FLOAT16) {
        return getInternalFormatForFloat16PackedMatrixTexture(textureConfig);
    }
    if (physicalTexType === PhysicalTextureType.UNPACKED_FLOAT32) {
        return getInternalFormatForFloat32MatrixTexture(textureConfig);
    }
    if (physicalTexType === PhysicalTextureType.UNPACKED_FLOAT16) {
        return getInternalFormatForFloat16MatrixTexture(textureConfig);
    }
    if (physicalTexType === PhysicalTextureType.PACKED_4X1_UNSIGNED_BYTE) {
        return getInternalFormatForUnsignedBytesMatrixTexture(textureConfig);
    }
    throw new Error("Unknown physical texture type " + physicalTexType);
}
|
|
/**
 * Picks the physical texture type for render targets: float32 variants when
 * the `WEBGL_RENDER_FLOAT32_ENABLED` flag is on, float16 variants otherwise,
 * packed or unpacked according to `isPacked`.
 */
function getPhysicalTextureForRendering(isPacked) {
    var renderFloat32 = tf.env().getBool('WEBGL_RENDER_FLOAT32_ENABLED');
    if (isPacked) {
        return renderFloat32 ?
            PhysicalTextureType.PACKED_2X2_FLOAT32 :
            PhysicalTextureType.PACKED_2X2_FLOAT16;
    }
    return renderFloat32 ?
        PhysicalTextureType.UNPACKED_FLOAT32 :
        PhysicalTextureType.UNPACKED_FLOAT16;
}
|
|
/**
 * Translates a logical texture usage into a physical texture type.
 *
 * UPLOAD maps to packed float32; RENDER (or a null/undefined usage) defers to
 * `getPhysicalTextureForRendering`; DOWNLOAD and PIXELS map to the packed
 * unsigned-byte type.
 *
 * @throws {Error} For any other usage value.
 */
function getPhysicalFromLogicalTextureType(logicalTexType, isPacked) {
    if (logicalTexType === TextureUsage.UPLOAD) {
        return PhysicalTextureType.PACKED_2X2_FLOAT32;
    }
    if (logicalTexType === TextureUsage.RENDER || logicalTexType == null) {
        return getPhysicalTextureForRendering(isPacked);
    }
    var isReadback = logicalTexType === TextureUsage.DOWNLOAD ||
        logicalTexType === TextureUsage.PIXELS;
    if (!isReadback) {
        throw new Error("Unknown logical texture type " + logicalTexType);
    }
    return PhysicalTextureType.PACKED_4X1_UNSIGNED_BYTE;
}
|
|
/**
 * Builds the pool key `rows_cols_physicalTexType_isPacked` used to index the
 * free/used texture maps.
 */
function getKeyFromTextureShape(shapeRowsCol, physicalTexType, isPacked) {
    var rows = shapeRowsCol[0];
    var cols = shapeRowsCol[1];
    // String#concat stringifies every part, so an undefined `isPacked` still
    // shows up as the text "undefined" in the key.
    return "".concat(rows, "_", cols, "_", physicalTexType, "_", isPacked);
}
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2017 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Shader program description for an element-wise unary op on unpacked data.
 * `opSnippet` becomes the body of the GLSL `unaryOperation(float x)` function
 * that `main` applies to the input value read at the output coordinates.
 */
var UnaryOpProgram = /** @class */ (function () {
    // GLSL emitted before and after the caller-supplied snippet.
    var SHADER_HEAD = "\n float unaryOperation(float x) {\n ";
    var SHADER_TAIL = "\n }\n\n void main() {\n float x = getAAtOutCoords();\n float y = unaryOperation(x);\n\n setOutput(y);\n }\n ";
    function UnaryOpProgram(aShape, opSnippet) {
        this.variableNames = ['A'];
        this.outputShape = aShape;
        this.userCode = SHADER_HEAD + opSnippet + SHADER_TAIL;
    }
    return UnaryOpProgram;
}());
|
|
// GLSL bodies for `unaryOperation(float x)` in UnaryOpProgram.
// Returns NaN inputs unchanged before the op-specific expression runs.
var CHECK_NAN_SNIPPET = 'if (isnan(x)) return x;';
var LINEAR = 'return x;';
var ABS = 'return abs(x);';
var ELU = 'return (x >= 0.0) ? x : (exp(x) - 1.0);';
// RELU / RELU6 are the NaN guard followed by the clamp, one statement per line.
var RELU = [CHECK_NAN_SNIPPET, ' return (x < 0.0) ? 0.0 : x;', ''].join('\n');
var RELU6 = [CHECK_NAN_SNIPPET, ' return (x < 0.0) ? 0.0 : min(6.0, x);', ''].join('\n');
var CLONE = "return x;";
var SIGMOID = 'return 1.0 / (1.0 + exp(-1.0 * x));';
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// GLSL snippets operating on a vec4 `x`; each one returns a value computed
// from all four channels (r, g, b, a).
// NOTE(review): presumably used as bodies of `unaryOperation(vec4 x)` in
// UnaryOpPackedProgram - the call sites are outside this chunk; confirm.
var LINEAR$1 = "return x;";
|
|
// Per-channel ELU: x when x >= 0, exp(x) - 1 otherwise.
var ELU$1 = "\n vec4 result;\n\n result.r = (x.r >= 0.0) ? x.r : (exp(x.r) - 1.0);\n result.g = (x.g >= 0.0) ? x.g : (exp(x.g) - 1.0);\n result.b = (x.b >= 0.0) ? x.b : (exp(x.b) - 1.0);\n result.a = (x.a >= 0.0) ? x.a : (exp(x.a) - 1.0);\n\n return result;\n";
|
|
// Per-channel ReLU; channels whose input is NaN keep the input value.
var RELU$1 = "\n vec4 result = x * vec4(greaterThanEqual(x, vec4(0.0)));\n bvec4 isNaN = isnan(x);\n\n result.r = isNaN.r ? x.r : result.r;\n result.g = isNaN.g ? x.g : result.g;\n result.b = isNaN.b ? x.b : result.b;\n result.a = isNaN.a ? x.a : result.a;\n\n return result;\n";
|
|
// Per-channel ReLU capped at 6; channels whose input is NaN keep the input value.
var RELU6$1 = "\n vec4 result = min(x, vec4(6.)) * vec4(greaterThanEqual(x, vec4(0.0)));\n bvec4 isNaN = isnan(x);\n\n result.r = isNaN.r ? x.r : result.r;\n result.g = isNaN.g ? x.g : result.g;\n result.b = isNaN.b ? x.b : result.b;\n result.a = isNaN.a ? x.a : result.a;\n\n return result;\n";
|
|
var SIGMOID$1 = "return 1.0 / (1.0 + exp(-1.0 * x));";
|
|
/**
 * Shader program description for an element-wise unary op on packed data.
 * `opSnippet` becomes the body of the GLSL `unaryOperation(vec4 x)` function
 * that `main` applies to the packed input read at the output coordinates.
 */
var UnaryOpPackedProgram = /** @class */ (function () {
    // GLSL emitted before and after the caller-supplied snippet.
    var SHADER_HEAD = "\n vec4 unaryOperation(vec4 x) {\n ";
    var SHADER_TAIL = "\n }\n\n void main() {\n vec4 x = getAAtOutCoords();\n vec4 y = unaryOperation(x);\n\n setOutput(y);\n }\n ";
    function UnaryOpPackedProgram(aShape, opSnippet) {
        this.variableNames = ['A'];
        // Both the input and the output of this program are packed textures.
        this.packedInputs = true;
        this.packedOutput = true;
        this.outputShape = aShape;
        this.userCode = SHADER_HEAD + opSnippet + SHADER_TAIL;
    }
    return UnaryOpPackedProgram;
}());
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Shader program description that reads a packed input and writes an
 * unpacked output of shape `outputShape`: for every output coordinate it
 * fetches the packed texel and selects one channel via `getChannel`.
 */
var UnpackProgram = /** @class */ (function () {
    function UnpackProgram(outputShape) {
        this.variableNames = ['A'];
        this.packedInputs = true;
        this.packedOutput = false;
        this.outputShape = outputShape;
        var rank = outputShape.length;
        var channels = getChannels('rc', rank);
        var coordsType = getCoordsDataType(rank);
        var sourceCoords = getSourceCoords(rank, channels);
        // Channel selection uses the two innermost coordinates, or `rc`
        // itself for rank <= 1.
        var channelCoords;
        if (rank <= 1) {
            channelCoords = 'rc';
        }
        else {
            channelCoords = "vec2(" + channels.slice(-2).join(',') + ")";
        }
        this.userCode = "\n void main() {\n " + coordsType + " rc = getOutputCoords();\n vec4 packedInput = getA(" + sourceCoords + ");\n\n setOutput(getChannel(packedInput, " + channelCoords + "));\n }\n ";
    }
    return UnpackProgram;
}());
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2017 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// Local alias for the `whereImpl` kernel implementation exposed by the
// tfjs-core module (`tf`); used by MathBackendWebGL.prototype.where.
var whereImpl = tf.kernel_impls.whereImpl;
|
|
// Epsilon constants for the two float precisions.
// NOTE(review): presumably the values reported as the backend's epsilon for
// float32 vs. float16 rendering - the consumer is outside this chunk; confirm.
var EPSILON_FLOAT32 = 1e-7;
|
|
var EPSILON_FLOAT16 = 1e-4;
|
|
// One shared cache object per WebGL version, created on first request.
var binaryCaches = {};
/**
 * Returns the cache object associated with `webGLVersion`, creating an empty
 * one the first time a version is requested. Repeated calls with the same
 * version return the same object.
 */
function getBinaryCache(webGLVersion) {
    if (!(webGLVersion in binaryCaches)) {
        binaryCaches[webGLVersion] = {};
    }
    return binaryCaches[webGLVersion];
}
|
|
// Empirically determined constant used to determine size threshold for handing
|
|
// off execution to the CPU.
|
|
// The value is read from the environment flag once, when this statement runs,
// and serves as the default `sizeThreshold` of
// MathBackendWebGL.prototype.shouldExecuteOnCPU.
var CPU_HANDOFF_SIZE_THRESHOLD = tf.env().getNumber('CPU_HANDOFF_SIZE_THRESHOLD');
|
|
// Empirically determined constant used to decide the number of MB on GPU
// before we warn about high memory use. The MB are this constant * screen area
// * dpi / 1024 / 1024.
var BEFORE_PAGING_CONSTANT = 600;
/**
 * Computes the GPU memory budget, in MB, after which a high-memory warning
 * applies. Falls back to 1024 MB when the environment's global object has no
 * `screen`.
 */
function numMBBeforeWarning() {
    var display = tf.env().global.screen;
    if (display == null) {
        return 1024; // 1 GB.
    }
    var scaledPixels = display.height * display.width * window.devicePixelRatio;
    return scaledPixels * BEFORE_PAGING_CONSTANT / 1024 / 1024;
}
|
|
var MathBackendWebGL = /** @class */ (function (_super) {
|
|
__extends(MathBackendWebGL, _super);
|
|
/**
 * Creates the WebGL backend.
 *
 * @param gpgpu Optional GPGPU context to reuse. When omitted, a context is
 *     created locally for the WebGL version configured in the environment
 *     and the per-version shared binary cache is used; a supplied context
 *     gets a fresh, private binary cache.
 * @throws {Error} If the environment reports that WebGL is unavailable.
 */
function MathBackendWebGL(gpgpu) {
    var _this = _super.call(this) || this;
    // Maps data ids that have a pending read operation, to list of subscribers.
    _this.pendingRead = new WeakMap();
    // List of data ids that are scheduled for disposal, but are waiting on a
    // pending read operation.
    _this.pendingDisposal = new WeakSet();
    // Used to count the number of 'shallow' sliced tensors that point to the
    // same data id.
    _this.dataRefCount = new WeakMap();
    _this.numBytesInGPU = 0;
    // Accumulated time spent (including blocking) in uploading data to webgl.
    _this.uploadWaitMs = 0;
    // Accumulated time spent (including blocking) in downloading data from
    // webgl.
    _this.downloadWaitMs = 0;
    // Record of the last manual GL flush time.
    _this.lastGlFlushTime = 0;
    _this.warnedAboutMemory = false;
    _this.pendingDeletes = 0;
    _this.disposed = false;
    if (!tf.env().getBool('HAS_WEBGL')) {
        throw new Error('WebGL is not supported on this device');
    }
    var ownsContext = gpgpu == null;
    if (ownsContext) {
        var webGLVersion = tf.env().getNumber('WEBGL_VERSION');
        var gl = getWebGLContext(webGLVersion);
        _this.binaryCache = getBinaryCache(webGLVersion);
        _this.gpgpu = new GPGPUContext(gl);
        _this.canvas = gl.canvas;
    }
    else {
        _this.gpgpu = gpgpu;
        _this.binaryCache = {};
        _this.canvas = gpgpu.gl.canvas;
    }
    _this.gpgpuCreatedLocally = ownsContext;
    _this.textureManager = new TextureManager(_this.gpgpu);
    _this.numMBBeforeWarning = numMBBeforeWarning();
    _this.texData = new tf.DataStorage(_this, tf.engine());
    return _this;
}
|
|
/**
 * Returns the current value of the static `MathBackendWebGL.nextDataId`
 * counter and advances the counter by one.
 */
MathBackendWebGL.prototype.nextDataId = function () {
    // Unary plus mirrors the numeric conversion done by the postfix `++`.
    var issued = +MathBackendWebGL.nextDataId;
    MathBackendWebGL.nextDataId = issued + 1;
    return issued;
};
|
|
/**
 * Number of data ids held by the backend, not counting entries whose
 * deletion is pending (`pendingDeletes`).
 */
MathBackendWebGL.prototype.numDataIds = function () {
    var stored = this.texData.numDataIds();
    return stored - this.pendingDeletes;
};
|
|
/**
 * Registers CPU-side `values` under a freshly issued data id.
 *
 * The entry is stored with UPLOAD usage and a refCount of 1. When the
 * `WEBGL_CHECK_NUMERICAL_PROBLEMS` or `DEBUG` flag is on, the values are
 * first passed to `checkNumericalProblems`.
 *
 * @returns The new data id object (`{id: ...}`).
 * @throws {Error} If non-null values are written with dtype 'complex64'.
 */
MathBackendWebGL.prototype.write = function (values, shape, dtype) {
    var env = tf.env();
    var shouldCheck = env.getBool('WEBGL_CHECK_NUMERICAL_PROBLEMS') || env.getBool('DEBUG');
    if (shouldCheck) {
        this.checkNumericalProblems(values);
    }
    if (dtype === 'complex64' && values != null) {
        throw new Error("Cannot write to a complex64 dtype. " +
            "Please use tf.complex(real, imag).");
    }
    var dataId = { id: this.nextDataId() };
    var entry = {
        shape: shape,
        dtype: dtype,
        values: values,
        usage: TextureUsage.UPLOAD,
        refCount: 1
    };
    this.texData.set(dataId, entry);
    return dataId;
};
|
|
/**
 * Returns the refCount stored for `dataId`, or 0 when the backend holds no
 * entry for it.
 */
MathBackendWebGL.prototype.refCount = function (dataId) {
    return this.texData.has(dataId) ? this.texData.get(dataId).refCount : 0;
};
|
|
/**
 * Increments the refCount of the entry stored for `dataId`. The entry is
 * assumed to exist; no existence check is performed.
 */
MathBackendWebGL.prototype.incRef = function (dataId) {
    this.texData.get(dataId).refCount++;
};
|
|
/**
 * Decrements the refCount of the entry stored for `dataId`; does nothing
 * when the backend holds no entry for it.
 */
MathBackendWebGL.prototype.decRef = function (dataId) {
    if (!this.texData.has(dataId)) {
        return;
    }
    this.texData.get(dataId).refCount--;
};
|
|
/**
 * Stores `values` under an existing `dataId` with UPLOAD usage and the given
 * refCount. In DEBUG mode the values are first passed to
 * `checkNumericalProblems`.
 *
 * @throws {Error} If `dtype` is 'complex64'.
 */
MathBackendWebGL.prototype.move = function (dataId, values, shape, dtype, refCount) {
    var debugging = tf.env().getBool('DEBUG');
    if (debugging) {
        this.checkNumericalProblems(values);
    }
    if (dtype === 'complex64') {
        throw new Error("Cannot write to a complex64 dtype. " +
            "Please use tf.complex(real, imag).");
    }
    var entry = {
        shape: shape,
        dtype: dtype,
        values: values,
        usage: TextureUsage.UPLOAD,
        refCount: refCount
    };
    this.texData.set(dataId, entry);
};
|
|
/**
 * Disposes the data behind an intermediate tensor info by forwarding its
 * `dataId` to `disposeData`.
 */
MathBackendWebGL.prototype.disposeIntermediateTensorInfo = function (tensorInfo) {
    var dataId = tensorInfo.dataId;
    this.disposeData(dataId);
};
|
|
/**
 * Synchronously reads the values stored for `dataId`.
 *
 * Shallow slices are first copied with a CLONE program and read from the
 * copy. Values already on the CPU are returned through
 * `convertAndCacheOnCPU`. Otherwise the values are downloaded from the
 * texture (complex64 data is assembled from its real and imaginary parts),
 * cached on the CPU and returned. While timers are active the download
 * duration is added to `downloadWaitMs`.
 */
MathBackendWebGL.prototype.readSync = function (dataId) {
    var entry = this.texData.get(dataId);
    var values = entry.values;
    var dtype = entry.dtype;
    var complexTensorInfos = entry.complexTensorInfos;
    var slice = entry.slice;
    var shape = entry.shape;
    var isPacked = entry.isPacked;
    // The presence of `slice` indicates this tensor is a shallow slice of a
    // different tensor, and is using that original tensor's texture. Run
    // `clone` in order to copy that texture and read from it.
    if (slice != null) {
        var cloneProgram = isPacked ?
            new UnaryOpPackedProgram(shape, CLONE) :
            new UnaryOpProgram(shape, CLONE);
        var copy = this.runWebGLProgram(cloneProgram, [{ dataId: dataId, shape: shape, dtype: dtype }], dtype);
        var copiedValues = this.readSync(copy.dataId);
        this.disposeIntermediateTensorInfo(copy);
        return copiedValues;
    }
    if (values != null) {
        return this.convertAndCacheOnCPU(dataId);
    }
    if (dtype === 'string') {
        return values;
    }
    var timingActive = this.activeTimers != null;
    var startMs;
    if (timingActive) {
        startMs = tf.util.now();
    }
    var downloaded;
    if (dtype === 'complex64') {
        var realPart = this.readSync(complexTensorInfos.real.dataId);
        var imagPart = this.readSync(complexTensorInfos.imag.dataId);
        downloaded = tf.backend_util.mergeRealAndImagArrays(realPart, imagPart);
    }
    else {
        downloaded = this.getValuesFromTexture(dataId);
    }
    if (timingActive) {
        this.downloadWaitMs += tf.util.now() - startMs;
    }
    return this.convertAndCacheOnCPU(dataId, downloaded);
};
|
|
/**
 * Asynchronously reads the values stored for `dataId`; returns a promise of
 * the values.
 *
 * The body is a transpiled async function: `__generator` runs it as a state
 * machine whose `case` labels are the resume points after each yielded
 * promise. Concurrent reads of the same `dataId` share a single download:
 * later callers are queued in `pendingRead` and resolved when the first
 * read completes. A disposal that was deferred while the read was pending is
 * carried out once the values are available.
 */
MathBackendWebGL.prototype.read = function (dataId) {
|
|
return __awaiter(this, void 0, void 0, function () {
|
|
var subscribers_1, texData, values, shape, slice, dtype, complexTensorInfos, isPacked, program, res, data, buffer, tmpDownloadTarget, tmpData, vals, ps, realValues, imagValues, size, dTypeVals, subscribers;
|
|
var _a;
|
|
return __generator(this, function (_b) {
|
|
switch (_b.label) {
|
|
case 0:
|
|
// A read of this data id is already in flight: subscribe to its result
// instead of starting a second download.
if (this.pendingRead.has(dataId)) {
|
|
subscribers_1 = this.pendingRead.get(dataId);
|
|
return [2 /*return*/, new Promise(function (resolve) { return subscribers_1.push(resolve); })];
|
|
}
|
|
texData = this.texData.get(dataId);
|
|
values = texData.values, shape = texData.shape, slice = texData.slice, dtype = texData.dtype, complexTensorInfos = texData.complexTensorInfos, isPacked = texData.isPacked;
|
|
// The presence of `slice` indicates this tensor is a shallow slice of a
|
|
// different tensor, and is using that original tensor's texture. Run
|
|
// `clone` in order to copy that texture and read from it.
|
|
if (slice != null) {
|
|
program = void 0;
|
|
if (isPacked) {
|
|
program = new UnaryOpPackedProgram(shape, CLONE);
|
|
}
|
|
else {
|
|
program = new UnaryOpProgram(shape, CLONE);
|
|
}
|
|
res = this.runWebGLProgram(program, [{ dataId: dataId, shape: shape, dtype: dtype }], dtype);
|
|
// `data` is the promise returned by the nested read; the copy is disposed
// right after the read has been started.
data = this.read(res.dataId);
|
|
this.disposeIntermediateTensorInfo(res);
|
|
return [2 /*return*/, data];
|
|
}
|
|
// Values are already available on the CPU.
if (values != null) {
|
|
return [2 /*return*/, this.convertAndCacheOnCPU(dataId)];
|
|
}
|
|
if (!tf.env().getBool('WEBGL_DOWNLOAD_FLOAT_ENABLED') &&
|
|
tf.env().getNumber('WEBGL_VERSION') === 2) {
|
|
throw new Error("tensor.data() with WEBGL_DOWNLOAD_FLOAT_ENABLED=false and " +
|
|
"WEBGL_VERSION=2 not yet supported.");
|
|
}
|
|
buffer = null;
|
|
if (dtype !== 'complex64' && tf.env().get('WEBGL_BUFFER_SUPPORTED')) {
|
|
// Possibly copy the texture into a buffer before inserting a fence.
|
|
tmpDownloadTarget = this.decode(dataId);
|
|
tmpData = this.texData.get(tmpDownloadTarget.dataId);
|
|
buffer = (_a = this.gpgpu).createBufferFromTexture.apply(_a, [tmpData.texture].concat(getDenseTexShape(shape)));
|
|
}
|
|
// Mark the read as pending so concurrent callers subscribe (see case 0).
this.pendingRead.set(dataId, []);
|
|
// complex64 data skips the fence and jumps straight to label 2.
if (!(dtype !== 'complex64')) return [3 /*break*/, 2];
|
|
// Create a fence and wait for it to resolve.
|
|
return [4 /*yield*/, this.gpgpu.createAndWaitForFence()];
|
|
case 1:
|
|
// Create a fence and wait for it to resolve.
|
|
_b.sent();
|
|
_b.label = 2;
|
|
case 2:
|
|
// Non-complex data continues at label 4; complex64 reads both parts.
if (!(dtype === 'complex64')) return [3 /*break*/, 4];
|
|
return [4 /*yield*/, Promise.all([
|
|
this.read(complexTensorInfos.real.dataId),
|
|
this.read(complexTensorInfos.imag.dataId)
|
|
])];
|
|
case 3:
|
|
ps = _b.sent();
|
|
realValues = ps[0];
|
|
imagValues = ps[1];
|
|
vals = tf.backend_util.mergeRealAndImagArrays(realValues, imagValues);
|
|
return [3 /*break*/, 5];
|
|
case 4:
|
|
// No buffer was created above: download directly from the texture.
if (buffer == null) {
|
|
vals = this.getValuesFromTexture(dataId);
|
|
}
|
|
else {
|
|
size = tf.util.sizeFromShape(shape);
|
|
vals = this.gpgpu.downloadFloat32MatrixFromBuffer(buffer, size);
|
|
}
|
|
_b.label = 5;
|
|
case 5:
|
|
if (tmpDownloadTarget != null) {
|
|
this.disposeIntermediateTensorInfo(tmpDownloadTarget);
|
|
}
|
|
dTypeVals = this.convertAndCacheOnCPU(dataId, vals);
|
|
subscribers = this.pendingRead.get(dataId);
|
|
this.pendingRead.delete(dataId);
|
|
// Notify all pending reads.
|
|
subscribers.forEach(function (resolve) { return resolve(dTypeVals); });
|
|
// A disposeData call made while this read was pending was deferred;
// perform it now.
if (this.pendingDisposal.has(dataId)) {
|
|
this.pendingDisposal.delete(dataId);
|
|
if (this.disposeData(dataId)) {
|
|
tf.engine().removeDataId(dataId, this);
|
|
}
|
|
this.pendingDeletes--;
|
|
}
|
|
return [2 /*return*/, dTypeVals];
|
|
}
|
|
});
|
|
});
|
|
};
|
|
/**
 * Synchronously reads tensor info `t` into a buffer created with
 * `tf.buffer`. For string tensors the stored byte values are decoded to
 * strings first.
 *
 * @throws {Error} If decoding the string bytes fails.
 */
MathBackendWebGL.prototype.bufferSync = function (t) {
    var raw = this.readSync(t.dataId);
    if (t.dtype !== 'string') {
        return tf.buffer(t.shape, t.dtype, raw);
    }
    var decoded;
    try {
        // Decode the bytes into string.
        decoded = raw.map(function (bytes) { return tf.util.decodeString(bytes); });
    }
    catch (_a) {
        throw new Error('Failed to decode encoded string bytes into utf-8');
    }
    return tf.buffer(t.shape, t.dtype, decoded);
};
|
|
/**
 * Verifies that every entry of `values` passes `canBeRepresented`; null or
 * undefined `values` are accepted as-is.
 *
 * @throws {Error} On the first value that fails the check. When the
 *     `WEBGL_RENDER_FLOAT32_CAPABLE` flag is on, the message suggests
 *     enabling float32 rendering.
 */
MathBackendWebGL.prototype.checkNumericalProblems = function (values) {
    if (values == null) {
        return;
    }
    for (var idx = 0; idx < values.length; idx++) {
        var value = values[idx];
        if (canBeRepresented(value)) {
            continue;
        }
        if (tf.env().getBool('WEBGL_RENDER_FLOAT32_CAPABLE')) {
            throw Error("The value " + value + " cannot be represented with your " +
                "current settings. Consider enabling float32 rendering: " +
                "'tf.env().set('WEBGL_RENDER_FLOAT32_ENABLED', true);'");
        }
        throw Error("The value " + value + " cannot be represented on this device.");
    }
};
|
|
/**
 * Downloads the values stored in the texture of `dataId`.
 *
 * With `WEBGL_DOWNLOAD_FLOAT_ENABLED` the data is decoded into a temporary
 * tensor and downloaded as floats. Otherwise the data is run through a
 * float-encoding program (packed variant when `WEBGL_PACK` is on and the
 * data is packed) and downloaded as byte-encoded floats. In both cases the
 * temporary tensor is disposed and the result is trimmed to the number of
 * elements of the logical shape.
 */
MathBackendWebGL.prototype.getValuesFromTexture = function (dataId) {
    var info = this.texData.get(dataId);
    var shape = info.shape;
    var dtype = info.dtype;
    var isPacked = info.isPacked;
    var size = tf.util.sizeFromShape(shape);
    if (tf.env().getBool('WEBGL_DOWNLOAD_FLOAT_ENABLED')) {
        var decoded = this.decode(dataId);
        var decodedData = this.texData.get(decoded.dataId);
        var downloadArgs = [decodedData.texture].concat(getDenseTexShape(shape));
        var floats = this.gpgpu.downloadMatrixFromPackedTexture
            .apply(this.gpgpu, downloadArgs)
            .subarray(0, size);
        this.disposeIntermediateTensorInfo(decoded);
        return floats;
    }
    var usePackedEncoder = tf.env().getBool('WEBGL_PACK') && isPacked === true;
    var outputShape = usePackedEncoder ? getShapeAs3D(shape) : shape;
    var encoder = usePackedEncoder ?
        new EncodeFloatPackedProgram(outputShape) :
        new EncodeFloatProgram(outputShape);
    var encoded = this.runWebGLProgram(encoder, [{ shape: outputShape, dtype: dtype, dataId: dataId }], 'float32');
    var encodedData = this.texData.get(encoded.dataId);
    var encodedShape = encodedData.texShape;
    var result = this.gpgpu
        .downloadByteEncodedFloatMatrixFromOutputTexture(encodedData.texture, encodedShape[0], encodedShape[1])
        .subarray(0, size);
    this.disposeIntermediateTensorInfo(encoded);
    return result;
};
|
|
/**
 * Whether GPU timer queries are usable, i.e. the
 * `WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE` flag is greater than 0.
 */
MathBackendWebGL.prototype.timerAvailable = function () {
    var reliability = tf.env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE');
    return reliability > 0;
};
|
|
/**
 * Runs `f` while collecting program timers and resolves to a timing report
 * with `uploadWaitMs`, `downloadWaitMs`, `kernelMs` and `wallMs` (left null
 * here).
 *
 * The body is a transpiled async function driven by `__generator`; `case`
 * labels are resume points. When reliable GPU timer queries are available,
 * `kernelMs` is the sum of all collected query times and
 * `getExtraProfileInfo` lists them per timer name; otherwise `kernelMs` is
 * an `{error: ...}` object. The upload/download wait accumulators are reset
 * to 0 before returning.
 */
MathBackendWebGL.prototype.time = function (f) {
|
|
return __awaiter(this, void 0, void 0, function () {
|
|
var oldActiveTimers, newActiveTimers, outerMostTime, flattenedActiveTimerQueries, flattenedActiveTimerNames, res, kernelMs_1;
|
|
return __generator(this, function (_a) {
|
|
switch (_a.label) {
|
|
case 0:
|
|
oldActiveTimers = this.activeTimers;
|
|
newActiveTimers = [];
|
|
outerMostTime = false;
|
|
// The outermost time() call starts the timer stack; nested calls append
// their timer list to the enclosing call's list.
if (this.programTimersStack == null) {
|
|
this.programTimersStack = newActiveTimers;
|
|
outerMostTime = true;
|
|
}
|
|
else {
|
|
this.activeTimers.push(newActiveTimers);
|
|
}
|
|
this.activeTimers = newActiveTimers;
|
|
f();
|
|
// Collect the non-null queries and names recorded while `f` ran.
flattenedActiveTimerQueries = tf.util.flatten(this.activeTimers.map(function (d) { return d.query; }))
|
|
.filter(function (d) { return d != null; });
|
|
flattenedActiveTimerNames = tf.util.flatten(this.activeTimers.map(function (d) { return d.name; }))
|
|
.filter(function (d) { return d != null; });
|
|
// Restore the caller's timer list.
this.activeTimers = oldActiveTimers;
|
|
if (outerMostTime) {
|
|
this.programTimersStack = null;
|
|
}
|
|
res = {
|
|
uploadWaitMs: this.uploadWaitMs,
|
|
downloadWaitMs: this.downloadWaitMs,
|
|
kernelMs: null,
|
|
wallMs: null // will be filled by the engine
|
|
};
|
|
// Without reliable timer queries, jump to label 2 (error report).
if (!(tf.env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE') > 0)) return [3 /*break*/, 2];
|
|
return [4 /*yield*/, Promise.all(flattenedActiveTimerQueries)];
|
|
case 1:
|
|
kernelMs_1 = _a.sent();
|
|
res['kernelMs'] = tf.util.sum(kernelMs_1);
|
|
// Formats the per-timer results as "name: ms, name: ms, ...".
res['getExtraProfileInfo'] = function () {
|
|
return kernelMs_1.map(function (d, i) { return ({ name: flattenedActiveTimerNames[i], ms: d }); })
|
|
.map(function (d) { return d.name + ": " + d.ms; })
|
|
.join(', ');
|
|
};
|
|
return [3 /*break*/, 3];
|
|
case 2:
|
|
res['kernelMs'] = {
|
|
error: 'WebGL query timers are not supported in this environment.'
|
|
};
|
|
_a.label = 3;
|
|
case 3:
|
|
// Reset the accumulators so the next time() call starts from zero.
this.uploadWaitMs = 0;
|
|
this.downloadWaitMs = 0;
|
|
return [2 /*return*/, res];
|
|
}
|
|
});
|
|
});
|
|
};
|
|
/**
 * Reports GPU memory usage: the backend's own byte count plus the texture
 * manager's allocated and free byte counts. `unreliable` is always false.
 */
MathBackendWebGL.prototype.memory = function () {
    var textures = this.textureManager;
    var report = {
        unreliable: false,
        numBytesInGPU: this.numBytesInGPU,
        numBytesInGPUAllocated: textures.numBytesAllocated,
        numBytesInGPUFree: textures.numBytesFree
    };
    return report;
};
|
|
/**
 * Starts a timer. Returns a GPU query when reliable timer queries are
 * available, otherwise a `{startMs, endMs}` record with `endMs` still null.
 */
MathBackendWebGL.prototype.startTimer = function () {
    var gpuTimersReliable = tf.env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE') > 0;
    if (!gpuTimersReliable) {
        return { startMs: tf.util.now(), endMs: null };
    }
    return this.gpgpu.beginQuery();
};
|
|
/**
 * Stops a timer started by `startTimer` and returns the same `query`
 * object. With reliable GPU timer queries the GPU query is ended; otherwise
 * the current time is recorded in `query.endMs`.
 */
MathBackendWebGL.prototype.endTimer = function (query) {
    var gpuTimersReliable = tf.env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE') > 0;
    if (gpuTimersReliable) {
        this.gpgpu.endQuery();
    }
    else {
        query.endMs = tf.util.now();
    }
    return query;
};
|
|
/**
 * Resolves to the elapsed time of a finished timer. With reliable GPU timer
 * queries the result comes from `gpgpu.waitForQueryAndGetTime`; otherwise it
 * is `endMs - startMs` of the CPU timer record.
 */
MathBackendWebGL.prototype.getQueryTime = function (query) {
    return __awaiter(this, void 0, void 0, function () {
        var cpuTimer;
        return __generator(this, function (_a) {
            var gpuTimersReliable = tf.env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE') > 0;
            if (!gpuTimersReliable) {
                cpuTimer = query;
                return [2 /*return*/, cpuTimer.endMs - cpuTimer.startMs];
            }
            return [2 /*return*/, this.gpgpu.waitForQueryAndGetTime(query)];
        });
    });
};
|
|
/**
 * Decreases the refCount of the dataId and disposes the memory once the
 * refCount reaches 0. If a read of the data is pending, the disposal is
 * deferred by adding the dataId to the pending disposal set. Returns true if
 * the dataId was removed from the backend or the backend does not contain
 * the dataId, false if the dataId was not removed. Memory may or may not be
 * released even when the dataId is removed, which also depends on
 * dataRefCount, see `releaseGPUData`.
 * @param dataId
 * @param force Optional, remove the data regardless of refCount
 */
MathBackendWebGL.prototype.disposeData = function (dataId, force) {
    if (force === void 0) { force = false; }
    // A disposal of this id is already queued behind a pending read.
    if (this.pendingDisposal.has(dataId)) {
        return false;
    }
    // No-op if already disposed.
    if (!this.texData.has(dataId)) {
        return true;
    }
    var entry = this.texData.get(dataId);
    if (force) {
        // Forcing zeroes the refCount, which ensures disposal even when it
        // has to wait in the pendingDisposal queue.
        entry.refCount = 0;
    }
    else {
        entry.refCount--;
        if (entry.refCount > 0) {
            // Still referenced elsewhere.
            return false;
        }
    }
    if (this.pendingRead.has(dataId)) {
        // Defer until the pending read has delivered its values.
        this.pendingDisposal.add(dataId);
        this.pendingDeletes++;
        return false;
    }
    this.releaseGPUData(dataId);
    var complexParts = this.texData.get(dataId).complexTensorInfos;
    if (complexParts != null) {
        this.disposeData(complexParts.real.dataId, force);
        this.disposeData(complexParts.imag.dataId, force);
    }
    this.texData.delete(dataId);
    return true;
};
|
|
/**
 * Drops the GPU-side state of `dataId`.
 *
 * The texture may be shared with shallow slices, tracked in `dataRefCount`
 * under the original data id. While other users remain only that count is
 * decremented; the last user removes the count and, if a texture exists,
 * returns it to the texture manager and subtracts its bytes from
 * `numBytesInGPU`. The entry's texture fields are cleared in every case.
 */
MathBackendWebGL.prototype.releaseGPUData = function (dataId) {
    var entry = this.texData.get(dataId);
    var texture = entry.texture;
    var dtype = entry.dtype;
    var texShape = entry.texShape;
    var usage = entry.usage;
    var isPacked = entry.isPacked;
    var slice = entry.slice;
    // Slices are counted under the data id they were sliced from.
    var ownerId = (slice && slice.origDataId) || dataId;
    var sharers = this.dataRefCount.get(ownerId);
    if (sharers > 1) {
        this.dataRefCount.set(ownerId, sharers - 1);
    }
    else {
        this.dataRefCount.delete(ownerId);
        if (texture != null) {
            this.numBytesInGPU -= this.computeBytes(texShape, dtype);
            this.textureManager.releaseTexture(texture, texShape, usage, isPacked);
        }
    }
    var current = this.texData.get(dataId);
    current.texture = null;
    current.texShape = null;
    current.isPacked = false;
    current.slice = null;
};
|
|
/**
 * Returns the texture stored for `dataId`, calling `uploadToGPU` for it
 * first.
 */
MathBackendWebGL.prototype.getTexture = function (dataId) {
    this.uploadToGPU(dataId);
    var entry = this.texData.get(dataId);
    return entry.texture;
};
|
|
/**
 * Returns the internal bookkeeping entry stored for the given data id.
 * Used in unit tests.
 */
MathBackendWebGL.prototype.getDataInfo = function (dataId) {
    var info = this.texData.get(dataId);
    return info;
};
|
|
/*
Tests whether all the inputs to an op are small and on the CPU. This heuristic
determines when it would be faster to execute a kernel on the CPU. WebGL
kernels opt into running this check and forwarding when appropriate.
An input qualifies when it has no texture and fewer than `sizeThreshold`
elements; `sizeThreshold` defaults to CPU_HANDOFF_SIZE_THRESHOLD. The check
only applies when the WEBGL_CPU_FORWARD flag is on.
TODO(https://github.com/tensorflow/tfjs/issues/872): Develop a more
sustainable strategy for optimizing backend execution of ops.
*/
MathBackendWebGL.prototype.shouldExecuteOnCPU = function (inputs, sizeThreshold) {
    if (sizeThreshold === void 0) { sizeThreshold = CPU_HANDOFF_SIZE_THRESHOLD; }
    var backend = this;
    var isSmallAndOnCPU = function (input) {
        var notOnGPU = backend.texData.get(input.dataId).texture == null;
        return notOnGPU && tf.util.sizeFromShape(input.shape) < sizeThreshold;
    };
    var forwardingEnabled = tf.env().getBool('WEBGL_CPU_FORWARD');
    // Mirrors `flag && every(...)`: a falsy flag value is returned as-is.
    if (!forwardingEnabled) {
        return forwardingEnabled;
    }
    return inputs.every(isSmallAndOnCPU);
};
|
|
/** @returns The GPGPU context used by this backend. */
MathBackendWebGL.prototype.getGPGPUContext = function () {
    var context = this.gpgpu;
    return context;
};
|
|
/**
 * Synchronous `where`: warns that this path locks the UI thread, reads the
 * condition values with `dataSync()` and delegates to `whereImpl`.
 */
MathBackendWebGL.prototype.where = function (condition) {
    var warning = 'tf.where() in webgl locks the UI thread. Call tf.whereAsync() instead';
    tf.backend_util.warn(warning);
    var conditionValues = condition.dataSync();
    var conditionShape = condition.shape;
    return whereImpl(conditionShape, conditionValues);
};
|
|
/**
 * Runs the unary shader snippet `op` over `x` with a UnaryOpPackedProgram and
 * wraps the resulting tensor info in a tensor via the engine.
 */
MathBackendWebGL.prototype.packedUnaryOp = function (x, op, dtype) {
    var result = this.compileAndRun(new UnaryOpPackedProgram(x.shape, op), [x], dtype);
    return tf.engine().makeTensorFromDataId(result.dataId, result.shape, result.dtype);
};
|
|
// TODO(msoulanille) remove this once the backend has been modularized
|
|
// a copy is needed here to break a circular dependency.
|
|
// Also remove the op from unary_op.
|
|
MathBackendWebGL.prototype.abs = function (x) {
    // TODO: handle cases when x is complex.
    var runOnCPU = this.shouldExecuteOnCPU([x]) && x.dtype !== 'complex64';
    if (runOnCPU) {
        // CPU path: work directly on the values cached for x.
        var cpuValues = simpleAbsImplCPU(this.texData.get(x.dataId).values);
        return this.makeOutput(x.shape, x.dtype, cpuValues);
    }
    if (tf.env().getBool('WEBGL_PACK_UNARY_OPERATIONS')) {
        return this.packedUnaryOp(x, ABS, x.dtype);
    }
    // Unpacked shader path; compileAndRun falls back to x's dtype.
    var info = this.compileAndRun(new UnaryOpProgram(x.shape, ABS), [x]);
    return tf.engine().makeTensorFromDataId(info.dataId, info.shape, info.dtype);
};
|
|
/**
 * Writes `values` into a new data bucket and returns its tensor info.
 *
 * @param shape Logical shape recorded for the new data.
 * @param dtype Dtype recorded for the new data.
 * @param values Optional initial values; may be omitted.
 * @returns `{dataId, shape, dtype}` describing the new bucket.
 */
MathBackendWebGL.prototype.makeTensorInfo = function (shape, dtype, values) {
    var holdsRawStrings = dtype === 'string' && values != null &&
        values.length > 0 && tf.util.isString(values[0]);
    // Raw JS strings go through tf.util.encodeString before being written.
    var payload = holdsRawStrings ?
        values.map(function (str) { return tf.util.encodeString(str); }) :
        values;
    var dataId = this.write(payload, shape, dtype);
    // Clear the usage recorded by write(); callers set it later if needed.
    this.texData.get(dataId).usage = null;
    return { dataId: dataId, shape: shape, dtype: dtype };
};
|
|
/**
 * Same as makeTensorInfo, but wraps the new data id in a tensor created by
 * the engine and bound to this backend.
 */
MathBackendWebGL.prototype.makeOutput = function (shape, dtype, values) {
    var info = this.makeTensorInfo(shape, dtype, values);
    return tf.engine().makeTensorFromDataId(info.dataId, shape, dtype, this);
};
|
|
/** Runs an UnpackProgram over `input` and returns the resulting tensor info. */
MathBackendWebGL.prototype.unpackTensor = function (input) {
    var unpackProgram = new UnpackProgram(input.shape);
    var unpacked = this.runWebGLProgram(unpackProgram, [input], input.dtype);
    return unpacked;
};
|
|
/**
 * Runs a PackProgram over `input` and returns the resulting tensor info. The
 * output is kept packed (eager unpacking is disabled for this run).
 */
MathBackendWebGL.prototype.packTensor = function (input) {
    var packProgram = new PackProgram(input.shape);
    return this.runWebGLProgram(packProgram, [input], input.dtype, null /* customSetup */, true /* preventEagerUnpackingOfOutput */);
};
|
|
/**
 * Reshapes a packed tensor through the ReshapePackedProgram shader.
 *
 * @param input Tensor info to reshape; only dataId, dtype and shape are read.
 * @param afterShape Logical shape the result should report.
 * @returns Tensor info for the reshaped data, labelled with `afterShape`.
 */
MathBackendWebGL.prototype.packedReshape = function (input, afterShape) {
    // Both shapes are expressed as [batch, rows, cols] for the shader.
    var toShape3D = function (shape) {
        return [getBatchDim(shape)].concat(getRowsCols(shape));
    };
    var sourceShape3D = toShape3D(input.shape);
    var source3D = {
        dtype: input.dtype,
        shape: sourceShape3D,
        dataId: input.dataId
    };
    var targetShape3D = toShape3D(afterShape);
    var reshapeProgram = new ReshapePackedProgram(targetShape3D, sourceShape3D);
    var reshaped = this.runWebGLProgram(reshapeProgram, [source3D], input.dtype, null /* customSetup */, true /* preventEagerUnpackingOfOutput */);
    // Report the caller's requested shape rather than the 3D working shape.
    return { dataId: reshaped.dataId, shape: afterShape, dtype: reshaped.dtype };
};
|
|
/**
 * Runs the decode-matrix shader (packed or unpacked variant, depending on how
 * the data is currently stored) over `dataId`.
 *
 * @param dataId Id of the data bucket to decode.
 * @returns Tensor info carrying the original shape/dtype and the new data id.
 */
MathBackendWebGL.prototype.decode = function (dataId) {
    var record = this.texData.get(dataId);
    var shape = record.shape;
    var dtype = record.dtype;
    var shape3D = getShapeAs3D(shape);
    var decodeProgram = record.isPacked ?
        new DecodeMatrixPackedProgram(shape3D) :
        new DecodeMatrixProgram(shape3D);
    var decodeInput = { shape: shape3D, dtype: dtype, dataId: dataId };
    var decoded = this.runWebGLProgram(decodeProgram, [decodeInput], dtype, null /* customSetup */, true /* preventEagerUnpackingOfOutput */);
    return { dtype: dtype, shape: shape, dataId: decoded.dataId };
};
|
|
/**
 * Runs a shader program over `inputs` and returns the output tensor info.
 *
 * @param program Program description; this method reads outputShape,
 *     packedInputs, packedOutput, outPackingScheme and outTexUsage from it.
 * @param inputs Tensor infos ({dataId, shape, dtype}) fed to the program.
 * @param outputDtype Dtype recorded for the output.
 * @param customSetup Optional callback forwarded unchanged to runProgram.
 * @param preventEagerUnpackingOfOutput When true, a packed output is returned
 *     as-is even if WEBGL_LAZILY_UNPACK is off. Defaults to false.
 * @returns Tensor info of the program output.
 * @throws Error when an input has dtype 'complex64'.
 */
MathBackendWebGL.prototype.runWebGLProgram = function (program, inputs, outputDtype, customSetup, preventEagerUnpackingOfOutput) {
    var _this = this;
    if (preventEagerUnpackingOfOutput === void 0) { preventEagerUnpackingOfOutput = false; }
    var output = this.makeTensorInfo(program.outputShape, outputDtype);
    var outData = this.texData.get(output.dataId);
    if (program.packedOutput) {
        outData.isPacked = true;
    }
    if (program.outPackingScheme === PackingScheme.DENSE) {
        var texelShape = getDenseTexShape(program.outputShape);
        // For a densely packed output, we explicitly set texShape
        // so it doesn't get assigned later according to our typical packing
        // scheme wherein a single texel can only contain values from adjacent
        // rows/cols.
        outData.texShape = texelShape.map(function (d) { return d * 2; });
    }
    if (program.outTexUsage != null) {
        outData.usage = program.outTexUsage;
    }
    if (tf.util.sizeFromShape(output.shape) === 0) {
        // Short-circuit the computation since the result is empty (has 0 in its
        // shape).
        outData.values =
            tf.util.getTypedArrayFromDType(output.dtype, 0);
        return output;
    }
    // Temporary tensor infos created while preparing inputs; released after
    // the program has run.
    var dataToDispose = [];
    var inputsData = inputs.map(function (input) {
        if (input.dtype === 'complex64') {
            throw new Error("GPGPUProgram does not support complex64 input. For complex64 " +
                "dtypes, please separate the program into real and imaginary " +
                "parts.");
        }
        var texData = _this.texData.get(input.dataId);
        if (texData.texture == null) {
            if (!program.packedInputs &&
                tf.util.sizeFromShape(input.shape) <=
                    tf.env().getNumber('WEBGL_SIZE_UPLOAD_UNIFORM')) {
                // Upload small tensors that live on the CPU as uniforms, not as
                // textures. Do this only when the environment supports 32bit floats
                // due to problems when comparing 16bit floats with 32bit floats.
                // TODO(https://github.com/tensorflow/tfjs/issues/821): Make it
                // possible for packed shaders to sample from uniforms.
                return {
                    shape: input.shape,
                    texData: null,
                    isUniform: true,
                    uniformValues: texData.values
                };
            }
            // This ensures that if a packed program's inputs have not yet been
            // uploaded to the GPU, they get uploaded as packed right off the bat.
            if (program.packedInputs) {
                texData.isPacked = true;
                texData.shape = input.shape;
            }
        }
        else if (!!texData.isPacked !== !!program.packedInputs) {
            // Stored packing differs from what the program expects: convert.
            input = texData.isPacked ? _this.unpackTensor(input) :
                _this.packTensor(input);
            dataToDispose.push(input);
            texData = _this.texData.get(input.dataId);
        }
        else if (texData.isPacked &&
            !isReshapeFree(texData.shape, input.shape)) {
            // This is a special case where a texture exists for a tensor
            // but the shapes are incompatible (due to packing constraints) because
            // the tensor did not have a chance to go through the packed reshape
            // shader. This only happens when we reshape the *same* tensor to form
            // *distinct* inputs to an op, e.g. dotting a vector with itself. This
            // case will disappear once packed uploading is the default.
            var targetShape = input.shape;
            // packedReshape only reads dataId, dtype and shape, so describe the
            // texture's current layout with a fresh descriptor instead of
            // temporarily overwriting the caller-owned `input.shape` (which would
            // stay corrupted if the reshape threw).
            var currentLayout = {
                dataId: input.dataId,
                dtype: input.dtype,
                shape: texData.shape
            };
            input = _this.packedReshape(currentLayout, targetShape);
            dataToDispose.push(input);
            texData = _this.texData.get(input.dataId);
        }
        _this.uploadToGPU(input.dataId);
        return { shape: input.shape, texData: texData, isUniform: false };
    });
    this.uploadToGPU(output.dataId);
    var outputData = { shape: output.shape, texData: outData, isUniform: false };
    var key = makeShaderKey(program, inputsData, outputData);
    // Compiled binaries are cached by shader key.
    var binary = this.getAndSaveBinary(key, function () {
        return compileProgram(_this.gpgpu, program, inputsData, outputData);
    });
    var shouldTimeProgram = this.activeTimers != null;
    var query;
    if (shouldTimeProgram) {
        query = this.startTimer();
    }
    runProgram(this.gpgpu, binary, inputsData, outputData, customSetup);
    dataToDispose.forEach(function (info) { return _this.disposeIntermediateTensorInfo(info); });
    if (shouldTimeProgram) {
        query = this.endTimer(query);
        this.activeTimers.push({ name: program.constructor.name, query: this.getQueryTime(query) });
    }
    var glFlushThreshold = tf.env().get('WEBGL_FLUSH_THRESHOLD');
    // Manually GL flush requested
    if (glFlushThreshold > 0) {
        var time = tf.util.now();
        if ((time - this.lastGlFlushTime) > glFlushThreshold) {
            this.gpgpu.gl.flush();
            this.lastGlFlushTime = time;
        }
    }
    if (!tf.env().getBool('WEBGL_LAZILY_UNPACK') && outData.isPacked &&
        preventEagerUnpackingOfOutput === false) {
        // Eager mode: hand back an unpacked copy and drop the packed output.
        var unpacked = this.unpackTensor(output);
        this.disposeIntermediateTensorInfo(output);
        return unpacked;
    }
    return output;
};
|
|
/**
 * Thin wrapper around runWebGLProgram that defaults the output dtype to the
 * dtype of the first input and the unpacking flag to false.
 */
MathBackendWebGL.prototype.compileAndRun = function (program, inputs, outputDtype, customSetup, preventEagerUnpackingOfOutput) {
    var keepPacked = preventEagerUnpackingOfOutput === void 0 ? false : preventEagerUnpackingOfOutput;
    // Any falsy outputDtype falls back to the first input's dtype.
    var resolvedDtype = outputDtype || inputs[0].dtype;
    return this.runWebGLProgram(program, inputs, resolvedDtype, customSetup, keepPacked);
};
|
|
/**
 * Returns the cached binary for `key`, calling `getBinary()` to create and
 * cache it on the first request.
 */
MathBackendWebGL.prototype.getAndSaveBinary = function (key, getBinary) {
    var cache = this.binaryCache;
    if (key in cache) {
        return cache[key];
    }
    cache[key] = getBinary();
    return this.binaryCache[key];
};
|
|
/** Returns the texture manager object held by this backend. */
MathBackendWebGL.prototype.getTextureManager = function () {
    var manager = this.textureManager;
    return manager;
};
|
|
/**
 * Releases the resources held by this backend: cached programs (outside of
 * tests), the texture manager, the canvas and, when it was created locally,
 * the GPGPU context. Calling it again after the first call is a no-op.
 */
MathBackendWebGL.prototype.dispose = function () {
    if (this.disposed) {
        return;
    }
    var backend = this;
    // Avoid disposing the compiled webgl programs during unit testing because
    // it slows down test execution.
    if (!tf.env().getBool('IS_TEST')) {
        Object.keys(this.binaryCache).forEach(function (cacheKey) {
            backend.gpgpu.deleteProgram(backend.binaryCache[cacheKey].webGLProgram);
            delete backend.binaryCache[cacheKey];
        });
    }
    this.textureManager.dispose();
    // HTMLCanvasElement may not exist as a global, hence the typeof guard.
    var isDomCanvas = this.canvas != null &&
        typeof (HTMLCanvasElement) !== 'undefined' &&
        this.canvas instanceof HTMLCanvasElement;
    if (isDomCanvas) {
        this.canvas.remove();
    }
    else {
        this.canvas = null;
    }
    if (this.gpgpuCreatedLocally) {
        this.gpgpu.program = null;
        this.gpgpu.dispose();
    }
    this.disposed = true;
};
|
|
/**
 * Returns the float precision (16 or 32) determined for this backend. The
 * value is computed on the first call and cached in `floatPrecisionValue`.
 *
 * When WEBGL_RENDER_FLOAT32_ENABLED is off, a tiny value (1e-8) is pushed
 * through `abs`; if it survives as a positive number the result is 32,
 * otherwise 16.
 */
MathBackendWebGL.prototype.floatPrecision = function () {
    var _this = this;
    if (this.floatPrecisionValue == null) {
        this.floatPrecisionValue = tf.tidy(function () {
            if (!tf.env().get('WEBGL_RENDER_FLOAT32_ENABLED')) {
                // Momentarily switching DEBUG flag to false so we don't throw an
                // error trying to upload a small value.
                var debugFlag = tf.env().getBool('DEBUG');
                tf.env().set('DEBUG', false);
                var underflowCheckValue;
                try {
                    underflowCheckValue = _this.abs(tf.scalar(1e-8)).dataSync()[0];
                }
                finally {
                    // Always restore the global flag, even if the probe throws.
                    tf.env().set('DEBUG', debugFlag);
                }
                if (underflowCheckValue > 0) {
                    return 32;
                }
            }
            return 16;
        });
    }
    return this.floatPrecisionValue;
};
|
|
/** Returns the epsilon constant matching this backend's float precision (32-bit or 16-bit). */
|
|
MathBackendWebGL.prototype.epsilon = function () {
    // Pick the constant that matches the detected float precision.
    if (this.floatPrecision() === 32) {
        return EPSILON_FLOAT32;
    }
    return EPSILON_FLOAT16;
};
|
|
/**
 * Makes sure a texture exists for `dataId`.
 *
 * - Texture already present: nothing happens.
 * - No CPU values: an empty texture is acquired.
 * - CPU values present: they are uploaded into a temporary dense texture,
 *   run through an encode-matrix program, and the encoded output texture is
 *   adopted by `dataId`; the CPU copy is then dropped.
 */
MathBackendWebGL.prototype.uploadToGPU = function (dataId) {
    var texData = this.texData.get(dataId);
    var shape = texData.shape;
    var dtype = texData.dtype;
    var values = texData.values;
    var usage = texData.usage;
    var isPacked = texData.isPacked;
    if (texData.texture != null) {
        // Array is already on GPU. No-op.
        return;
    }
    var isTiming = this.activeTimers != null;
    var startMs;
    if (isTiming) {
        startMs = tf.util.now();
    }
    var texShape = texData.texShape;
    if (texShape == null) {
        texShape = getTextureShapeFromLogicalShape(shape, isPacked);
        texData.texShape = texShape;
    }
    if (values == null) {
        // Nothing to upload: just reserve a texture of the right shape.
        texData.texture = this.acquireTexture(texShape, usage, dtype, isPacked);
        return;
    }
    var shapeAs3D = getShapeAs3D(shape);
    var isByteArray = values instanceof Uint8Array;
    var width = texShape[1];
    var height = texShape[0];
    var encodeProgram;
    if (isPacked) {
        var packedDims = getPackedMatrixTextureShapeWidthHeight(texShape[0], texShape[1]);
        width = packedDims[0];
        height = packedDims[1];
        encodeProgram = new EncodeMatrixPackedProgram(shapeAs3D, [height, width], isByteArray);
    }
    else {
        encodeProgram = new EncodeMatrixProgram(shapeAs3D, [height, width], isByteArray);
    }
    var tempDenseInputHandle = this.makeTensorInfo([height, width], dtype);
    // Byte arrays are uploaded with PIXELS usage, everything else with UPLOAD.
    this.texData.get(tempDenseInputHandle.dataId).usage =
        isByteArray ? TextureUsage.PIXELS : TextureUsage.UPLOAD;
    this.gpgpu.uploadDenseMatrixToTexture(this.getTexture(tempDenseInputHandle.dataId), width, height, values);
    // We want the output to remain packed regardless of the value of
    // WEBGL_PACK.
    var encodedOutputTarget = this.runWebGLProgram(encodeProgram, [tempDenseInputHandle], dtype, null, true /* preventEagerUnpacking */);
    // Have the original texture assume the identity of the encoded output.
    var encodedData = this.texData.get(encodedOutputTarget.dataId);
    texData.texture = encodedData.texture;
    texData.texShape = encodedData.texShape;
    texData.isPacked = encodedData.isPacked;
    texData.usage = encodedData.usage;
    this.disposeIntermediateTensorInfo(tempDenseInputHandle);
    // Only the bookkeeping entry is removed; its texture now belongs to dataId.
    this.texData.delete(encodedOutputTarget.dataId);
    // Once uploaded, don't store the values on cpu.
    texData.values = null;
    if (isTiming) {
        this.uploadWaitMs += tf.util.now() - startMs;
    }
};
|
|
/**
 * Releases the GPU data of `dataId` and, when `float32Values` is given,
 * stores it (converted to the bucket's dtype) as the CPU copy.
 *
 * @returns The values now cached on the bucket.
 */
MathBackendWebGL.prototype.convertAndCacheOnCPU = function (dataId, float32Values) {
    var record = this.texData.get(dataId);
    // Read the dtype before the GPU side of the record is released.
    var dtype = record.dtype;
    this.releaseGPUData(dataId);
    if (float32Values != null) {
        record.values = float32ToTypedArray(float32Values, dtype);
    }
    return record.values;
};
|
|
/**
 * Accounts for the bytes of a new texture, warns once when the running total
 * exceeds `numMBBeforeWarning`, and then asks the texture manager for it.
 */
MathBackendWebGL.prototype.acquireTexture = function (texShape, texType, dtype, isPacked) {
    this.numBytesInGPU += this.computeBytes(texShape, dtype);
    var overLimit = !this.warnedAboutMemory &&
        this.numBytesInGPU > this.numMBBeforeWarning * 1024 * 1024;
    if (overLimit) {
        var usedMB = (this.numBytesInGPU / 1024 / 1024).toFixed(2);
        // The flag guarantees the warning is printed at most once.
        this.warnedAboutMemory = true;
        console.warn("High memory usage in GPU: " + usedMB + " MB, " +
            "most likely due to a memory leak");
    }
    return this.textureManager.acquireTexture(texShape, texType, isPacked);
};
|
|
/**
 * Byte size used for GPU memory accounting: the product of the first two
 * entries of `shape` times the bytes per element of `dtype`.
 */
MathBackendWebGL.prototype.computeBytes = function (shape, dtype) {
    var elementCount = shape[0] * shape[1];
    return elementCount * tf.util.bytesPerElement(dtype);
};
|
|
// Static counter initialised to 0. NOTE(review): presumably used to number
// data ids; its consumer is outside this part of the file — confirm.
MathBackendWebGL.nextDataId = 0;
|
|
return MathBackendWebGL;
|
|
}(tf.KernelBackend));
|
|
/**
 * Converts float values into the typed array matching `dtype`.
 *
 * @param a Array-like of numbers (indexed access and `length` are used).
 * @param dtype 'float32' | 'complex64' | 'int32' | 'bool'.
 * @returns `a` itself for 'float32'/'complex64'; a new Int32Array ('int32')
 *     or Uint8Array ('bool') holding the rounded values otherwise.
 * @throws Error for any other dtype.
 */
function float32ToTypedArray(a, dtype) {
    switch (dtype) {
        case 'float32':
        case 'complex64':
            // No conversion needed; the input is returned as-is (not copied).
            return a;
        case 'int32':
        case 'bool': {
            var converted = dtype === 'int32' ?
                new Int32Array(a.length) :
                new Uint8Array(a.length);
            for (var idx = 0; idx < converted.length; idx++) {
                converted[idx] = Math.round(a[idx]);
            }
            return converted;
        }
        default:
            throw new Error("Unknown dtype " + dtype);
    }
}
|
|
|
|
/** @license See the LICENSE file. */
|
|
// This code is auto-generated, do not modify this file!
|
|
// Version string emitted by the build (auto-generated, do not edit by hand).
// NOTE(review): presumably the version of the WebGL backend package — confirm.
var version = '3.7.0';
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
|
|
* Enforce use of half precision textures if available on the platform.
|
|
*
|
|
* @doc {heading: 'Environment', namespace: 'webgl'}
|
|
*/
|
|
function forceHalfFloat() {
    // Flip the environment flag that forces 16-bit float textures.
    var environment = tf.env();
    environment.set('WEBGL_FORCE_F16_TEXTURES', true);
}
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google Inc. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// Register the WebGL backend only when running in a browser environment.
if (tf.device_util.isBrowser()) {
    tf.registerBackend(
        'webgl',
        function () {
            return new MathBackendWebGL();
        },
        2 /* priority */
    );
}
|
|
// Namespace-style object bundling the WebGL-specific helper forceHalfFloat.
var webgl = { forceHalfFloat: forceHalfFloat };
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2017 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// Shader-source snippet for a binary op body: returns operand `a` when it is
// NaN, otherwise operand `b` when that one is NaN.
var CHECK_NAN_SNIPPET$1 = "\n if (isnan(a)) return a;\n if (isnan(b)) return b;\n";
|
|
/**
 * Program description for an element-wise binary op over inputs `A` and `B`.
 *
 * `op` is the shader source placed inside `binaryOperation(float a, float b)`;
 * the output shape is the broadcast of `aShape` and `bShape`.
 */
var BinaryOpProgram = /** @class */ (function () {
    function BinaryOpProgram(op, aShape, bShape) {
        // Shader text surrounding the caller-supplied op body.
        var beforeOp = "\n float binaryOperation(float a, float b) {\n ";
        var afterOp = "\n }\n\n void main() {\n float a = getAAtOutCoords();\n float b = getBAtOutCoords();\n setOutput(binaryOperation(a, b));\n }\n ";
        this.variableNames = ['A', 'B'];
        // Throws (inside the helper) when the shapes cannot be broadcast.
        this.outputShape = tf.backend_util.assertAndGetBroadcastShape(aShape, bShape);
        this.userCode = beforeOp + op + afterOp;
    }
    return BinaryOpProgram;
}());
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// Shader-source snippet for packed (vec4) results: each channel of `result`
// is replaced by NAN where the matching channel of `isNaN` is positive.
var CHECK_NAN_SNIPPET$2 = "\n result.r = isNaN.r > 0. ? NAN : result.r;\n result.g = isNaN.g > 0. ? NAN : result.g;\n result.b = isNaN.b > 0. ? NAN : result.b;\n result.a = isNaN.a > 0. ? NAN : result.a;\n";
|
|
/**
 * Packed (vec4) program description for an element-wise binary op over
 * inputs `A` and `B`.
 *
 * `op` is the shader source placed inside `binaryOperation(vec4 a, vec4 b)`.
 * With `checkOutOfBounds` set, extra shader code zeroes the result channels
 * that fall outside the logical output shape.
 */
var BinaryOpPackedProgram = /** @class */ (function () {
    // Builds the shader code that zeroes out-of-range channels of `result`.
    function buildOutOfBoundsGuard(outputShape) {
        var rank = outputShape.length;
        if (rank === 0 || tf.util.sizeFromShape(outputShape) === 1) {
            // Single value: only the first channel is meaningful.
            return "\n result.y = 0.;\n result.z = 0.;\n result.w = 0.;\n ";
        }
        var coordsType = getCoordsDataType(rank);
        var guard = "\n " + coordsType + " coords = getOutputCoords();\n ";
        if (rank === 1) {
            return guard + "\n result.y = (coords + 1) >= " + outputShape[0] + " ? 0. : result.y;\n result.z = 0.;\n result.w = 0.;\n ";
        }
        var channels = getChannels('coords', rank);
        return guard + "\n bool nextRowOutOfBounds =\n (" + channels[rank - 2] + " + 1) >= " + outputShape[rank - 2] + ";\n bool nextColOutOfBounds =\n (" + channels[rank - 1] + " + 1) >= " + outputShape[rank - 1] + ";\n result.y = nextColOutOfBounds ? 0. : result.y;\n result.z = nextRowOutOfBounds ? 0. : result.z;\n result.w = nextColOutOfBounds || nextRowOutOfBounds ? 0. : result.w;\n ";
    }
    function BinaryOpPackedProgram(op, aShape, bShape, checkOutOfBounds) {
        var guardRequested = checkOutOfBounds === void 0 ? false : checkOutOfBounds;
        this.variableNames = ['A', 'B'];
        this.supportsBroadcasting = true;
        this.packedInputs = true;
        this.packedOutput = true;
        this.outputShape = tf.backend_util.assertAndGetBroadcastShape(aShape, bShape);
        var outOfBoundsGuard = guardRequested ? buildOutOfBoundsGuard(this.outputShape) : '';
        this.userCode = "\n vec4 binaryOperation(vec4 a, vec4 b) {\n " + op + "\n }\n\n void main() {\n vec4 a = getAAtOutCoords();\n vec4 b = getBAtOutCoords();\n\n vec4 result = binaryOperation(a, b);\n " + outOfBoundsGuard + "\n\n setOutput(result);\n }\n ";
    }
    return BinaryOpPackedProgram;
}());
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Identity kernel: returns a new tensor info that points at the same data as
 * the input `x`, after bumping the reference count of that data.
 *
 * @param args `{inputs: {x}, backend}`.
 * @returns `{dataId, shape, dtype}` copied from `x`.
 */
function identity(args) {
    var x = args.inputs.x;
    // No data is copied; the shared bucket simply gains one reference.
    args.backend.incRef(x.dataId);
    return { dataId: x.dataId, shape: x.shape, dtype: x.dtype };
}
|
|
// Kernel registration record pairing the Identity kernel name with the
// WebGL implementation above.
var identityConfig = {
    kernelName: tf.Identity,
    backendName: 'webgl',
    kernelFunc: identity
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
|
|
* In WebGL data is stored in GPU textures which can't be efficiently copied, so
|
|
* complex tensors share data with their real and imaginary components. Complex
|
|
* tensors' reference to the components is tracked by refCount on the individual
|
|
* component. The refCounts are increased by the identity call.
|
|
*
|
|
* When a complex tensor is disposed, it will reduce the refCount on the
|
|
* components by calling disposeData on each.
|
|
*/
|
|
function complex(args) {
    var backend = args.backend;
    var realPart = args.inputs.real;
    var imagPart = args.inputs.imag;
    // The complex tensor itself carries no values, only links to its parts.
    var complexInfo = backend.makeTensorInfo(realPart.shape, 'complex64');
    var complexData = backend.texData.get(complexInfo.dataId);
    // identity() shares the component data and bumps its reference count.
    var realInfo = identity({ inputs: { x: realPart }, backend: backend });
    var imagInfo = identity({ inputs: { x: imagPart }, backend: backend });
    complexData.complexTensorInfos = { real: realInfo, imag: imagInfo };
    return complexInfo;
}
|
|
// Kernel registration record pairing the Complex kernel name with the
// WebGL implementation above.
var complexConfig = {
    kernelName: tf.Complex,
    backendName: 'webgl',
    kernelFunc: complex
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// Shader op body (scalar form): negative `a` is scaled by `b`, everything
// else passes through unchanged.
var LEAKYRELU = "return (a < 0.) ? b * a : a;";
|
|
// Shader op body (vec4 form): builds a 0/1 mask of the channels where a < 0
// and blends `b * a` (masked channels) with `a` (remaining channels).
var LEAKYRELU_PACKED = "\n vec4 aLessThanZero = vec4(lessThan(a, vec4(0.)));\n return (aLessThanZero * (b * a)) + ((vec4(1.0) - aLessThanZero) * a);\n";
|
|
/**
 * LeakyRelu kernel. Materialises `attrs.alpha` as a temporary scalar tensor,
 * runs the (packed or unpacked) binary program over `x` and that scalar, and
 * disposes the temporary afterwards.
 *
 * @param args `{inputs: {x}, backend, attrs: {alpha}}`.
 * @returns Tensor info produced by the program, with x's dtype.
 */
function leakyRelu(args) {
    var backend = args.backend;
    var x = args.inputs.x;
    var alphaValue = tf.util.createScalarValue(args.attrs.alpha, 'float32');
    var $alpha = backend.makeTensorInfo([], 'float32', alphaValue);
    var program;
    if (tf.env().getBool('WEBGL_PACK_BINARY_OPERATIONS')) {
        program = new BinaryOpPackedProgram(LEAKYRELU_PACKED, x.shape, $alpha.shape);
    }
    else {
        program = new BinaryOpProgram(LEAKYRELU, x.shape, $alpha.shape);
    }
    var result = backend.runWebGLProgram(program, [x, $alpha], x.dtype);
    // The scalar was only needed as a shader input.
    backend.disposeIntermediateTensorInfo($alpha);
    return result;
}
|
|
// Kernel registration record pairing the LeakyRelu kernel name with the
// WebGL implementation above.
var leakyReluConfig = {
    kernelName: tf.LeakyRelu,
    backendName: 'webgl',
    kernelFunc: leakyRelu
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// Shader op body (scalar form): negative `a` is scaled by `b`, everything
// else passes through unchanged.
var PRELU = "return (a < 0.) ? b * a : a;";
|
|
// Shader op body (vec4 form): builds a 0/1 mask of the channels where a < 0
// and blends `b * a` (masked channels) with `a` (remaining channels).
var PRELU_PACKED = "\n vec4 aLessThanZero = vec4(lessThan(a, vec4(0.)));\n return (aLessThanZero * (b * a)) + ((vec4(1.0) - aLessThanZero) * a);\n";
|
|
/**
 * Prelu kernel: runs the (packed or unpacked) binary program over the input
 * `x` and the `alpha` tensor.
 *
 * @param args `{inputs: {x, alpha}, backend}`.
 * @returns Tensor info produced by the program, with x's dtype.
 */
function prelu(args) {
    var backend = args.backend;
    var x = args.inputs.x;
    var alpha = args.inputs.alpha;
    var usePacked = tf.env().getBool('WEBGL_PACK_BINARY_OPERATIONS');
    var program;
    if (usePacked) {
        program = new BinaryOpPackedProgram(PRELU_PACKED, x.shape, alpha.shape);
    }
    else {
        program = new BinaryOpProgram(PRELU, x.shape, alpha.shape);
    }
    return backend.runWebGLProgram(program, [x, alpha], x.dtype);
}
|
|
// Kernel registration record pairing the Prelu kernel name with the
// WebGL implementation above.
var preluConfig = {
    kernelName: tf.Prelu,
    backendName: 'webgl',
    kernelFunc: prelu
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// Shader-source snippet for a unary op body: returns `x` unchanged when it
// is NaN.
var CHECK_NAN_SNIPPET_UNARY = "if (isnan(x)) return x;";
|
|
// Shader-source snippet for a binary op body: returns operand `a` when it is
// NaN, otherwise operand `b` when that one is NaN.
var CHECK_NAN_SNIPPET_BINARY = "\n if (isnan(a)) return a;\n if (isnan(b)) return b;\n";
|
|
// Shader-source snippet for packed (vec4) binary results: each channel of
// `result` is replaced by NAN where the matching channel of `isNaN` is
// positive.
var CHECK_NAN_SNIPPET_BINARY_PACKED = "\n result.r = isNaN.r > 0. ? NAN : result.r;\n result.g = isNaN.g > 0. ? NAN : result.g;\n result.b = isNaN.b > 0. ? NAN : result.b;\n result.a = isNaN.a > 0. ? NAN : result.a;\n";
|
|
/**
|
|
* Template that creates a `KernelFunc` for unary ops.
|
|
* @param opSnippet Op snippet to create `UnaryOpProgram`.
|
|
* @param packedOpSnippet Op snippet to create `UnaryOpPackedProgram`.
|
|
* @param dtype Optional. If set, the result has this dtype. Otherwise, the
|
|
* result has the same dtype as the first input. This is mainly used in
|
|
* comparison kernels, such as Equal, Less, Greater, etc.
|
|
*/
|
|
function unaryKernelFunc(_a) {
    var opSnippet = _a.opSnippet;
    var packedOpSnippet = _a.packedOpSnippet;
    // Optional CPU implementation, called as cpuKernelImpl(values, dtype).
    var cpuKernelImpl = _a.cpuKernelImpl;
    var dtype = _a.dtype;
    return function (kernelArgs) {
        var x = kernelArgs.inputs.x;
        var webglBackend = kernelArgs.backend;
        // An explicit dtype wins; otherwise the output keeps the input dtype.
        var $dtype = dtype || x.dtype;
        if (webglBackend.shouldExecuteOnCPU([x]) && cpuKernelImpl != null) {
            // CPU forwarding: compute directly on the cached values of x.
            var cpuValues = cpuKernelImpl(webglBackend.texData.get(x.dataId).values, $dtype);
            return webglBackend.makeTensorInfo(x.shape, $dtype, cpuValues);
        }
        // The packed shader is used only when enabled AND a packed snippet
        // was supplied.
        var usePacked = tf.env().getBool('WEBGL_PACK_UNARY_OPERATIONS') && packedOpSnippet != null;
        var program = usePacked ?
            new UnaryOpPackedProgram(x.shape, packedOpSnippet) :
            new UnaryOpProgram(x.shape, opSnippet);
        return webglBackend.runWebGLProgram(program, [x], $dtype);
    };
}
|
|
/**
|
|
* Template that creates a `KernelFunc` for binary ops.
|
|
* @param opSnippet Op snippet to create `BinaryOpProgram`.
|
|
* @param packedOpSnippet Op snippet to create `BinaryOpPackedProgram`.
|
|
* @param checkOutOfBounds Whether to set checkOutOfBounds=true
|
|
* when creating BinaryOpPackedProgram.
|
|
* @param dtype Optional. If set, the result has this dtype. Otherwise, the
|
|
* result has the same dtype as the first input. This is mainly used in
|
|
* comparison kernels, such as Equal, Less, Greater, etc.
|
|
*/
|
|
function binaryKernelFunc(_a) {
    // Options: opSnippet / packedOpSnippet are shader bodies; checkOutOfBounds
    // and supportsComplex default to false; cpuKernelImpl (optional) is called
    // as cpuKernelImpl(aShape, bShape, aVals, bVals, dtype) and must return
    // [values, shape]; dtype (optional) overrides the upcast output dtype.
    var opSnippet = _a.opSnippet, packedOpSnippet = _a.packedOpSnippet, _b = _a.checkOutOfBounds, checkOutOfBounds = _b === void 0 ? false : _b, _c = _a.supportsComplex, supportsComplex = _c === void 0 ? false : _c, cpuKernelImpl = _a.cpuKernelImpl, dtype = _a.dtype;
    return function (_a) {
        var inputs = _a.inputs, backend = _a.backend;
        var _b = inputs, a = _b.a, b = _b.b;
        var webglBackend = backend;
        if (supportsComplex && a.dtype === 'complex64') {
            // Complex inputs: apply the op separately to the real and the
            // imaginary components, then recombine the two results.
            var aData = webglBackend.texData.get(a.dataId);
            var bData = webglBackend.texData.get(b.dataId);
            var _c = [
                [aData.complexTensorInfos.real, bData.complexTensorInfos.real],
                [aData.complexTensorInfos.imag, bData.complexTensorInfos.imag]
            ].map(function (complexParts) {
                var aPart = complexParts[0], bPart = complexParts[1];
                var aHandle = {
                    dataId: aPart.dataId,
                    dtype: aPart.dtype,
                    shape: a.shape
                };
                var bHandle = {
                    dataId: bPart.dataId,
                    dtype: bPart.dtype,
                    shape: b.shape
                };
                var program = new BinaryOpProgram(opSnippet, a.shape, b.shape);
                return webglBackend.runWebGLProgram(program, [aHandle, bHandle], tf.upcastType(aPart.dtype, bPart.dtype));
            }), real = _c[0], imag = _c[1];
            var complexOutput = complex({ inputs: { real: real, imag: imag }, backend: webglBackend });
            // complex() took its own references; drop the intermediates.
            webglBackend.disposeIntermediateTensorInfo(real);
            webglBackend.disposeIntermediateTensorInfo(imag);
            // TODO(annxingyuan): Implement CPU forwarding for complex inputs.
            return complexOutput;
        }
        var $dtype = dtype || tf.upcastType(a.dtype, b.dtype);
        if ((a.dtype === 'string' || b.dtype === 'string' ||
            webglBackend.shouldExecuteOnCPU([a, b])) &&
            cpuKernelImpl != null) {
            var aVals = webglBackend.texData.get(a.dataId).values;
            var bVals = webglBackend.texData.get(b.dataId).values;
            var decodedAVals = a.dtype === 'string' ?
                // tslint:disable-next-line: no-any
                tf.backend_util.fromUint8ToStringArray(aVals) :
                aVals;
            // Decode b based on b's own dtype (previously this tested a.dtype,
            // so b was decoded or left encoded according to the wrong input).
            var decodedBVals = b.dtype === 'string' ?
                // tslint:disable-next-line: no-any
                tf.backend_util.fromUint8ToStringArray(bVals) :
                bVals;
            var _d = cpuKernelImpl(a.shape, b.shape, decodedAVals, decodedBVals, $dtype), outValues = _d[0], outShape = _d[1];
            var out = webglBackend.makeTensorInfo(outShape, $dtype);
            var outData = webglBackend.texData.get(out.dataId);
            outData.values = outValues;
            return out;
        }
        // Shader path: packed only when enabled AND a packed snippet exists.
        var shouldUsePackedProgram = tf.env().getBool('WEBGL_PACK_BINARY_OPERATIONS') &&
            packedOpSnippet != null;
        var program;
        if (shouldUsePackedProgram) {
            program = new BinaryOpPackedProgram(packedOpSnippet, a.shape, b.shape, checkOutOfBounds);
        }
        else {
            program = new BinaryOpProgram(opSnippet, a.shape, b.shape);
        }
        return webglBackend.runWebGLProgram(program, [a, b], $dtype);
    };
}
|
|
/**
 * Looks up the shader snippet for a fused activation.
 *
 * @param activation Name of the activation ('linear', 'relu', 'elu',
 *     'relu6', 'prelu', 'leakyrelu' or 'sigmoid').
 * @param packed Optional, defaults to false. When truthy the packed variant
 *     of the snippet is returned.
 * @returns The snippet constant for the requested activation.
 * @throws Error when the activation name is not one of the names above.
 */
function mapActivationToShaderProgram(activation, packed) {
    if (packed === void 0) { packed = false; }
    switch (activation) {
        case 'linear':
            return packed ? LINEAR$1 : LINEAR;
        case 'relu':
            return packed ? RELU$1 : RELU;
        case 'elu':
            return packed ? ELU$1 : ELU;
        case 'relu6':
            return packed ? RELU6$1 : RELU6;
        case 'prelu':
            return packed ? PRELU_PACKED : PRELU;
        case 'leakyrelu':
            return packed ? LEAKYRELU_PACKED : LEAKYRELU;
        case 'sigmoid':
            return packed ? SIGMOID$1 : SIGMOID;
        default:
            throw new Error("Activation " + activation + " has not been implemented for the WebGL backend.");
    }
}
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
var MatMulPackedProgram = /** @class */ (function () {
    /**
     * Packed shader program description for a batched matrix multiply with
     * optional fused bias addition and activation.
     *
     * @param aShape Shape of matrix A; index 0 is read as the batch size and
     *     index 1 or 2 (depending on transposeA) as the shared dimension.
     * @param bShape Shape of matrix B; index 0 is read as the batch size.
     * @param outputShape Shape stored as the program's output shape.
     * @param transposeA Optional, defaults to false.
     * @param transposeB Optional, defaults to false.
     * @param addBias Optional, defaults to false. Adds a 'bias' input.
     * @param activation Optional, defaults to null. Shader code used as the
     *     body of the generated `activation` function.
     * @param hasPreluActivation Optional, defaults to false. Adds a
     *     'preluActivationWeights' input.
     * @param hasLeakyreluActivation Optional, defaults to false. Adds a
     *     'leakyreluAlpha' input.
     */
    function MatMulPackedProgram(aShape, bShape, outputShape, transposeA, transposeB, addBias, activation, hasPreluActivation, hasLeakyreluActivation) {
        transposeA = transposeA === void 0 ? false : transposeA;
        transposeB = transposeB === void 0 ? false : transposeB;
        addBias = addBias === void 0 ? false : addBias;
        activation = activation === void 0 ? null : activation;
        hasPreluActivation = hasPreluActivation === void 0 ? false : hasPreluActivation;
        hasLeakyreluActivation = hasLeakyreluActivation === void 0 ? false : hasLeakyreluActivation;
        var inputNames = ['matrixA', 'matrixB'];
        this.variableNames = inputNames;
        this.packedInputs = true;
        this.packedOutput = true;
        this.outputShape = outputShape;
        // Each loop iteration consumes two entries of the shared dimension.
        var sharedDimensionPacked = Math.ceil((transposeA ? aShape[1] : aShape[2]) / 2);
        // Sampling coordinates and swizzles depend on the transposition flags.
        var aSample;
        var aSwizzle;
        if (transposeA) {
            aSample = 'i * 2, rc.y';
            aSwizzle = ['a.xxyy', 'a.zzww'];
        }
        else {
            aSample = 'rc.y, i * 2';
            aSwizzle = ['a.xxzz', 'a.yyww'];
        }
        var bSample;
        var bSwizzle;
        if (transposeB) {
            bSample = 'rc.z, i * 2';
            bSwizzle = ['b.xzxz', 'b.ywyw'];
        }
        else {
            bSample = 'i * 2, rc.z';
            bSwizzle = ['b.xyxy', 'b.zwzw'];
        }
        // Optional fused activation: pick the function header, then append the
        // caller-provided body.
        var activationSnippet = '';
        var applyActivationSnippet = '';
        if (activation) {
            var activationHeader;
            if (hasPreluActivation) {
                activationHeader = "vec4 activation(vec4 a) {\n vec4 b = getPreluActivationWeightsAtOutCoords();\n ";
            }
            else if (hasLeakyreluActivation) {
                activationHeader = "vec4 activation(vec4 a) {\n vec4 b = getLeakyreluAlphaAtOutCoords();\n ";
            }
            else {
                activationHeader = "vec4 activation(vec4 x) {\n ";
            }
            activationSnippet = activationHeader + activation + "\n }";
            applyActivationSnippet = "result = activation(result);";
        }
        // Extra inputs are appended in a fixed order: bias, prelu, leakyrelu.
        var addBiasSnippet = '';
        if (addBias) {
            addBiasSnippet = 'result += getBiasAtOutCoords();';
            inputNames.push('bias');
        }
        if (hasPreluActivation) {
            inputNames.push('preluActivationWeights');
        }
        if (hasLeakyreluActivation) {
            inputNames.push('leakyreluAlpha');
        }
        // When one operand has the smaller batch size, clamp its batch index
        // to its last batch.
        var batchASnippet = 'rc.x';
        var batchBSnippet = 'rc.x';
        if (aShape[0] < bShape[0]) {
            batchASnippet = "int(min(float(rc.x), " + (aShape[0] - 1) + ".))";
        }
        else if (bShape[0] < aShape[0]) {
            batchBSnippet = "int(min(float(rc.x), " + (bShape[0] - 1) + ".))";
        }
        this.userCode = "\n " + activationSnippet + "\n\n" +
            " const float sharedDimension = " + sharedDimensionPacked + ".0;\n\n" +
            " vec4 dot2x2ARowBCol(ivec3 rc) {\n" +
            " vec4 result = vec4(0);\n" +
            " for (int i = 0; i < " + sharedDimensionPacked + "; i++) {\n" +
            " int batchA = " + batchASnippet + ";\n" +
            " int batchB = " + batchBSnippet + ";\n" +
            " vec4 a = getMatrixA(batchA, " + aSample + ");\n" +
            " vec4 b = getMatrixB(batchB, " + bSample + ");\n\n" +
            " // These swizzled products need to be separately added.\n" +
            " // See: https://github.com/tensorflow/tfjs/issues/1735\n" +
            " result += (" + aSwizzle[0] + " * " + bSwizzle[0] + ");\n" +
            " result += (" + aSwizzle[1] + " * " + bSwizzle[1] + ");\n" +
            " }\n" +
            " return result;\n" +
            " }\n\n" +
            " void main() {\n" +
            " ivec3 rc = getOutputCoords();\n" +
            " vec4 result = dot2x2ARowBCol(rc);\n\n" +
            " " + addBiasSnippet + "\n\n" +
            " " + applyActivationSnippet + "\n\n" +
            " setOutput(result);\n" +
            " }\n" +
            " ";
    }
    return MatMulPackedProgram;
}());
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// (Ar + i*Ai)(Br + i*Bi) =
// ArBr + i*ArBi + i*AiBr + i^2*AiBi = (ArBr - AiBi) + i*(ArBi + AiBr)
// Yr = ArBr - AiBi
// Yi = ArBi + AiBr
|
|
// Shader snippets computing the real and imaginary parts of a complex
// product from the operands' components (areal, aimag, breal, bimag).
var COMPLEX_MULTIPLY = (function () {
    var snippets = {};
    snippets.REAL = 'return areal * breal - aimag * bimag;';
    snippets.IMAG = 'return areal * bimag + aimag * breal;';
    return snippets;
}());
|
|
var BinaryOpComplexProgram = /** @class */ (function () {
    /**
     * Shader program description for a binary op over two complex inputs
     * supplied as four component inputs (AReal, AImag, BReal, BImag).
     *
     * @param op Shader code used as the body of `binaryOpComplex`.
     * @param aShape Shape of the first operand.
     * @param bShape Shape of the second operand.
     */
    function BinaryOpComplexProgram(op, aShape, bShape) {
        this.variableNames = ['AReal', 'AImag', 'BReal', 'BImag'];
        // The output takes the broadcast shape of the two operands.
        var broadcastShape = tf.backend_util.assertAndGetBroadcastShape(aShape, bShape);
        this.outputShape = broadcastShape;
        this.userCode = "\n float binaryOpComplex(\n" +
            " float areal, float aimag, float breal, float bimag) {\n" +
            " " + op + "\n" +
            " }\n\n" +
            " void main() {\n" +
            " float areal = getARealAtOutCoords();\n" +
            " float aimag = getAImagAtOutCoords();\n" +
            " float breal = getBRealAtOutCoords();\n" +
            " float bimag = getBImagAtOutCoords();\n" +
            " setOutput(binaryOpComplex(areal, aimag, breal, bimag));\n" +
            " }\n" +
            " ";
    }
    return BinaryOpComplexProgram;
}());
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// Shader snippet for element-wise multiplication.
var MUL = 'return a * b;';
/**
 * WebGL kernel for Multiply.
 *
 * Complex inputs are handled with two BinaryOpComplexProgram passes (real and
 * imaginary part). Otherwise the op is either forwarded to the CPU
 * implementation when the backend asks for it, or run as a (packed) binary op
 * program.
 *
 * @param args Kernel arguments: `inputs` ({a, b}) and `backend`.
 * @returns Tensor info of the product.
 */
function multiply(args) {
    var backend = args.backend;
    var operands = args.inputs;
    var a = operands.a;
    var b = operands.b;
    var dtype = tf.backend_util.upcastType(a.dtype, b.dtype);
    if (a.dtype === 'complex64') {
        var aTexData = backend.texData.get(a.dataId);
        var bTexData = backend.texData.get(b.dataId);
        var realProgram = new BinaryOpComplexProgram(COMPLEX_MULTIPLY.REAL, a.shape, b.shape);
        var imagProgram = new BinaryOpComplexProgram(COMPLEX_MULTIPLY.IMAG, a.shape, b.shape);
        // Wraps one component tensor so that it is addressed with the shape
        // of the complex tensor it belongs to.
        var asHandle = function (part, shape) {
            return { dataId: part.dataId, dtype: part.dtype, shape: shape };
        };
        // Input order must match the program's AReal, AImag, BReal, BImag.
        var componentInputs = [
            asHandle(aTexData.complexTensorInfos.real, a.shape),
            asHandle(aTexData.complexTensorInfos.imag, a.shape),
            asHandle(bTexData.complexTensorInfos.real, b.shape),
            asHandle(bTexData.complexTensorInfos.imag, b.shape)
        ];
        var realPart = backend.runWebGLProgram(realProgram, componentInputs, 'float32');
        var imagPart = backend.runWebGLProgram(imagProgram, componentInputs, 'float32');
        var complexOutput = complex({ inputs: { real: realPart, imag: imagPart }, backend: backend });
        // The component results were only needed to build the complex output.
        backend.disposeIntermediateTensorInfo(realPart);
        backend.disposeIntermediateTensorInfo(imagPart);
        // TODO(annxingyuan): CPU forwarding for complex inputs.
        return complexOutput;
    }
    if (backend.shouldExecuteOnCPU([a, b])) {
        var aValues = backend.texData.get(a.dataId);
        var bValues = backend.texData.get(b.dataId);
        var cpuResult = multiplyImplCPU(a.shape, b.shape, aValues.values, bValues.values, dtype);
        var out = backend.makeTensorInfo(cpuResult[1], dtype);
        // Attach the CPU-computed values to the freshly created tensor.
        backend.texData.get(out.dataId).values = cpuResult[0];
        return out;
    }
    var program = tf.env().getBool('WEBGL_PACK_BINARY_OPERATIONS') ?
        new BinaryOpPackedProgram(MUL, a.shape, b.shape) :
        new BinaryOpProgram(MUL, a.shape, b.shape);
    return backend.runWebGLProgram(program, [a, b], dtype);
}
// Kernel registration entry for Multiply on the 'webgl' backend.
var multiplyConfig = {
    kernelName: tf.Multiply,
    backendName: 'webgl',
    kernelFunc: multiply
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Reshapes a tensor by running a ReshapePackedProgram.
 *
 * Both the input shape and the target shape are first expressed as
 * [batch, rows, cols] via getBatchDim / getRowsCols.
 *
 * @param input Tensor info to reshape.
 * @param afterShape Target shape.
 * @param backend Backend used to run the program.
 * @returns Tensor info for the program output, carrying `afterShape`.
 */
function packedReshape(input, afterShape, backend) {
    // Expresses a shape as [batch, rows, cols].
    var as3D = function (shape) {
        return [getBatchDim(shape)].concat(getRowsCols(shape));
    };
    var sourceShape3D = as3D(input.shape);
    var source3D = {
        dtype: input.dtype,
        shape: sourceShape3D,
        dataId: input.dataId
    };
    var targetShape3D = as3D(afterShape);
    var program = new ReshapePackedProgram(targetShape3D, sourceShape3D);
    var result = backend.runWebGLProgram(program, [source3D], input.dtype, 
    /* customSetup */ null, 
    /* preventEagerUnpackingOfOutput */ true);
    return { dataId: result.dataId, shape: afterShape, dtype: result.dtype };
}
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * WebGL kernel for Reshape.
 *
 * Runs a packed reshape program when the input is packed and the reshape is
 * not free; otherwise returns a new tensor info that shares the input's data
 * (after incrementing its reference count).
 *
 * @param args Kernel arguments: `inputs` ({x}), `backend` and `attrs`
 *     ({shape}).
 * @returns Tensor info with the requested shape.
 */
function reshape(args) {
    var webglBackend = args.backend;
    var x = args.inputs.x;
    var requestedShape = args.attrs.shape;
    var xSize = tf.util.sizeFromShape(x.shape);
    var $shape = tf.util.inferFromImplicitShape(requestedShape, xSize);
    var $xSize = tf.util.sizeFromShape($shape);
    tf.util.assert(xSize === $xSize, function () {
        return "The new shape (" + $shape + ") has " + $xSize + " elements and the old " +
            ("shape (" + x.shape + ") has " + xSize + " elements. The new shape and old ") +
            "shape must have the same number of elements.";
    });
    var xTexData = webglBackend.texData.get(x.dataId);
    var needsPackedReshape = false;
    if (xTexData.isPacked && !isReshapeFree(x.shape, $shape)) {
        // Still free if an existing texture's own shape allows it.
        needsPackedReshape = xTexData.texture === null ||
            !isReshapeFree(xTexData.shape, $shape);
    }
    if (needsPackedReshape) {
        return packedReshape(x, $shape, webglBackend);
    }
    // The result aliases x's data, so keep that data alive.
    webglBackend.incRef(x.dataId);
    return { dataId: x.dataId, shape: $shape, dtype: x.dtype };
}
// Kernel registration entry for Reshape on the 'webgl' backend.
var reshapeConfig = {
    kernelName: tf.Reshape,
    backendName: 'webgl',
    kernelFunc: reshape
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
var MeanProgram = /** @class */ (function () {
    /**
     * Shader program description for one stage of a mean reduction. Each
     * output element sums a window of the input row, optionally scaling every
     * value by 1 / divisor.
     *
     * @param reduceInfo Object with windowSize, batchSize, inSize, outSize.
     * @param divisor Optional. When not null, values are multiplied by
     *     1 / divisor before being summed.
     */
    function MeanProgram(reduceInfo, divisor) {
        this.variableNames = ['x'];
        var windowSize = reduceInfo.windowSize;
        var batchSize = reduceInfo.batchSize;
        var inSize = reduceInfo.inSize;
        var outSize = reduceInfo.outSize;
        this.outputShape = [batchSize, outSize];
        // The window is processed four values at a time; the leftover 0-3
        // values are handled after the loop.
        var remainder = windowSize % 4;
        var vec4Span = Math.floor(windowSize / 4) * 4;
        var scale = '';
        if (divisor != null) {
            var denominator = 1 / divisor;
            // Integer factors are formatted with toPrecision(2).
            scale = " * " + (tf.util.isInt(denominator) ? denominator.toPrecision(2) :
                denominator);
        }
        var updateSnippet = "sumValue += dot(values" + scale + ", ones);";
        // A bounds check is only emitted when the last window is partial.
        var checkOutOfBounds = inSize % windowSize > 0 ?
            "\n if (inIdx < 0 || inIdx >= " + inSize + ") {\n return 0.0;\n }\n " :
            '';
        this.userCode = "\n const vec4 ones = vec4(1.0, 1.0, 1.0, 1.0);\n\n" +
            " float getValue(int batch, int inIdx) {\n" +
            " " + checkOutOfBounds + "\n" +
            " return getX(batch, inIdx);\n" +
            " }\n\n" +
            " void main() {\n" +
            " ivec2 coords = getOutputCoords();\n" +
            " int batch = coords[0];\n" +
            " int outIdx = coords[1];\n" +
            " int inOffset = outIdx * " + windowSize + ";\n\n" +
            " float sumValue = 0.0;\n\n" +
            " for (int i = 0; i < " + vec4Span + "; i += 4) {\n" +
            " int inIdx = inOffset + i;\n" +
            " vec4 values = vec4(\n" +
            " getValue(batch, inIdx),\n" +
            " getValue(batch, inIdx + 1),\n" +
            " getValue(batch, inIdx + 2),\n" +
            " getValue(batch, inIdx + 3)\n" +
            " );\n\n" +
            " " + updateSnippet + "\n" +
            " }\n\n" +
            " int inIdx = inOffset + " + vec4Span + ";\n" +
            " if (" + (remainder === 1) + ") {\n" +
            " vec4 values = vec4(getValue(batch, inIdx), 0.0, 0.0, 0.0);\n\n" +
            " " + updateSnippet + "\n" +
            " } else if (" + (remainder === 2) + ") {\n" +
            " vec4 values = vec4(\n" +
            " getValue(batch, inIdx),\n" +
            " getValue(batch, inIdx + 1), 0.0, 0.0);\n\n" +
            " " + updateSnippet + "\n" +
            " } else if (" + (remainder === 3) + ") {\n" +
            " vec4 values = vec4(\n" +
            " getValue(batch, inIdx),\n" +
            " getValue(batch, inIdx + 1),\n" +
            " getValue(batch, inIdx + 2), 0.0);\n\n" +
            " " + updateSnippet + "\n" +
            " }\n" +
            " setOutput(sumValue);\n" +
            " }\n" +
            " ";
    }
    return MeanProgram;
}());
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2017 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
var ReduceProgram = /** @class */ (function () {
    /**
     * Shader program description for one stage of a windowed reduction.
     *
     * @param reduceInfo Object with windowSize, batchSize, inSize, outSize.
     * @param reduceType Reduction name. 'sum', 'prod', 'min', 'max', 'all'
     *     and 'any' have dedicated handling.
     */
    function ReduceProgram(reduceInfo, reduceType) {
        this.variableNames = ['x'];
        var windowSize = reduceInfo.windowSize;
        var batchSize = reduceInfo.batchSize;
        var inSize = reduceInfo.inSize;
        var outSize = reduceInfo.outSize;
        this.outputShape = [batchSize, outSize];
        // Per-type settings; anything not set here keeps the defaults below.
        var initializationValue = '0.0';
        var compareOp = "";
        var vecType = "vec4";
        var returnValue = null;
        var updateSnippet = null;
        switch (reduceType) {
            case 'sum':
                returnValue = "sumValue";
                break;
            case 'prod':
                initializationValue = '1.0';
                returnValue = "prodValue";
                break;
            case 'min':
                // WebGL on Firefox Linux can't compile 1/0 so we do 1/eps.
                initializationValue = '1.0 / 1e-20';
                compareOp = "min";
                break;
            case 'max':
                // WebGL on Firefox Linux can't compile 1/0 so we do 1/eps.
                initializationValue = '-1.0 / 1e-20';
                compareOp = "max";
                break;
            case 'all':
                initializationValue = '1.0';
                returnValue = "allValue";
                vecType = "bvec4";
                updateSnippet = "\n bool reducedAllValue = all(values);\n" +
                    " float floatedReducedAllValue = float(reducedAllValue);\n" +
                    " allValue = float(allValue >= 1.0 && floatedReducedAllValue >= 1.0);\n" +
                    " ";
                break;
            case 'any':
                initializationValue = '0.0';
                returnValue = "anyValue";
                vecType = "bvec4";
                updateSnippet = "\n bool reducedAnyValue = any(values);\n" +
                    " float floatedReducedAnyValue = float(reducedAnyValue);\n" +
                    " anyValue = float(anyValue >= 1.0 || floatedReducedAnyValue >= 1.0);\n" +
                    " ";
                break;
        }
        if (returnValue === null) {
            // Fold the four lanes of minMaxValue with the reduction function.
            returnValue = reduceType + "(" + reduceType + "(" + reduceType + "(" +
                'minMaxValue[0], minMaxValue[1]), minMaxValue[2]), minMaxValue[3])';
        }
        if (updateSnippet === null) {
            // The branch conditions are baked into the shader as literals.
            updateSnippet = "\n if (" + (reduceType === 'sum') + ") {\n" +
                " sumValue += dot(values, ones);\n" +
                " } else if (" + (reduceType === 'prod') + ") {\n" +
                " vec2 tmp = vec2(values[0], values[1]) * vec2(values[2], values[3]);\n" +
                " prodValue *= tmp[0] * tmp[1];\n" +
                " } else {\n" +
                " minMaxValue = " + compareOp + "(values, minMaxValue);\n" +
                " if (" + (reduceType === 'min') + " || " + (reduceType === 'max') + ") {\n" +
                " minMaxValue = " + compareOp + "(values, minMaxValue);\n" +
                " bvec4 isNaN = isnan(values);\n" +
                " if (isNaN.r || isNaN.g || isNaN.b || isNaN.a) {\n" +
                " minMaxValue = vec4(NAN);\n" +
                " }\n" +
                " }\n" +
                " }\n" +
                " ";
        }
        // The window is processed four values at a time; the leftover 0-3
        // values are padded with the initialization value.
        var vec4Span = Math.floor(windowSize / 4) * 4;
        var remainder = windowSize % 4;
        // A bounds check is only emitted when the last window is partial.
        var checkOutOfBounds = inSize % windowSize > 0 ?
            "\n if (inIdx < 0 || inIdx >= " + inSize + ") {\n return initializationValue;\n }\n " :
            '';
        this.userCode = "\n const float initializationValue = " + initializationValue + ";\n" +
            " const vec4 ones = vec4(1.0, 1.0, 1.0, 1.0);\n\n" +
            " float getValue(int batch, int inIdx) {\n" +
            " " + checkOutOfBounds + "\n" +
            " return getX(batch, inIdx);\n" +
            " }\n\n" +
            " void main() {\n" +
            " ivec2 coords = getOutputCoords();\n" +
            " int batch = coords[0];\n" +
            " int outIdx = coords[1];\n" +
            " int inOffset = outIdx * " + windowSize + ";\n\n" +
            " vec4 minMaxValue = vec4(" + initializationValue + ");\n" +
            " float prodValue = 1.0;\n" +
            " float sumValue = 0.0;\n" +
            " float allValue = 1.0;\n" +
            " float anyValue = 0.0;\n\n" +
            " for (int i = 0; i < " + vec4Span + "; i += 4) {\n" +
            " int inIdx = inOffset + i;\n" +
            " " + vecType + " values = " + vecType + "(\n" +
            " getValue(batch, inIdx),\n" +
            " getValue(batch, inIdx + 1),\n" +
            " getValue(batch, inIdx + 2),\n" +
            " getValue(batch, inIdx + 3)\n" +
            " );\n\n" +
            " " + updateSnippet + "\n" +
            " }\n\n" +
            " int inIdx = inOffset + " + vec4Span + ";\n" +
            " if (" + (remainder === 1) + ") {\n" +
            " " + vecType + " values = " + vecType + "(\n" +
            " getValue(batch, inIdx),\n" +
            " initializationValue,\n" +
            " initializationValue,\n" +
            " initializationValue\n" +
            " );\n\n" +
            " " + updateSnippet + "\n" +
            " } else if (" + (remainder === 2) + ") {\n" +
            " " + vecType + " values = " + vecType + "(\n" +
            " getValue(batch, inIdx),\n" +
            " getValue(batch, inIdx + 1),\n" +
            " initializationValue,\n" +
            " initializationValue\n" +
            " );\n\n" +
            " " + updateSnippet + "\n" +
            " } else if (" + (remainder === 3) + ") {\n" +
            " " + vecType + " values = " + vecType + "(\n" +
            " getValue(batch, inIdx),\n" +
            " getValue(batch, inIdx + 1),\n" +
            " getValue(batch, inIdx + 2),\n" +
            " initializationValue\n" +
            " );\n\n" +
            " " + updateSnippet + "\n" +
            " }\n" +
            " setOutput(" + returnValue + ");\n" +
            " }\n" +
            " ";
    }
    return ReduceProgram;
}());
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Plans a multi-stage reduction over the second dimension of `inShape`.
 *
 * Stages are added until one reports an output size of 1. Each stage consumes
 * the previous stage's output size and uses the window size returned by
 * tf.backend_util.computeOptimalWindowSize.
 *
 * @param inShape Shape whose entry at index 1 is the size to reduce.
 * @returns Array of {inSize, windowSize, outSize}, one entry per stage.
 */
function getReductionStages(inShape) {
    var stages = [];
    var remaining = inShape[1];
    do {
        var windowSize = tf.backend_util.computeOptimalWindowSize(remaining);
        var stage = {
            inSize: remaining,
            windowSize: windowSize,
            outSize: Math.ceil(remaining / windowSize)
        };
        stages.push(stage);
        remaining = stage.outSize;
    } while (remaining !== 1);
    return stages;
}
|
|
/**
 * Runs a staged reduction over `x` on the WebGL backend.
 *
 * One program is run per stage returned by getReductionStages(x.shape). For
 * 'mean' only the first stage receives a divisor (its inSize); every other
 * reduction type uses ReduceProgram.
 *
 * @param x Input tensor info; x.shape[0] is used as the batch size.
 * @param dtype Output dtype passed to every program run.
 * @param reductionType Reduction name, e.g. 'mean' or a ReduceProgram type.
 * @param backend Backend used to run programs and dispose intermediates.
 * @returns Tensor info produced by the last stage.
 */
function reduce(x, dtype, reductionType, backend) {
    var current = x;
    getReductionStages(x.shape).forEach(function (stage, stageIndex) {
        var info = {
            windowSize: stage.windowSize,
            inSize: stage.inSize,
            batchSize: x.shape[0],
            outSize: stage.outSize
        };
        var program;
        if (reductionType !== 'mean') {
            program = new ReduceProgram(info, reductionType);
        }
        else if (stageIndex === 0) {
            program = new MeanProgram(info, stage.inSize);
        }
        else {
            program = new MeanProgram(info);
        }
        var stageInput = current;
        current = backend.runWebGLProgram(program, [stageInput], dtype);
        // Intermediate stage outputs are released; the caller's input is not.
        if (stageInput.dataId !== x.dataId) {
            backend.disposeIntermediateTensorInfo(stageInput);
        }
    });
    return current;
}
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2017 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
var TransposeProgram = /** @class */ (function () {
    /**
     * Shader program description for an unpacked transpose.
     *
     * @param aShape Shape of the input.
     * @param newDim Permutation: output axis i takes input axis newDim[i].
     */
    function TransposeProgram(aShape, newDim) {
        this.variableNames = ['A'];
        var permutedShape = new Array(aShape.length);
        var axis = 0;
        while (axis < permutedShape.length) {
            permutedShape[axis] = aShape[newDim[axis]];
            axis++;
        }
        this.outputShape = permutedShape;
        this.rank = permutedShape.length;
        var coordsType = getCoordsDataType(this.rank);
        // Output coordinates rearranged into the input's axis order.
        var sourceCoords = getSwitchedCoords(newDim);
        this.userCode = "\n void main() {\n" +
            " " + coordsType + " resRC = getOutputCoords();\n" +
            " setOutput(getA(" + sourceCoords + "));\n" +
            " }\n" +
            " ";
    }
    return TransposeProgram;
}());
|
|
/**
 * Builds the comma-separated input coordinates for a transpose shader.
 *
 * The i-th output coordinate component (resRC.x, resRC.y, ...) is placed at
 * position newDim[i] of the result.
 *
 * @param newDim Permutation of the axes; at most 6 entries.
 * @returns Comma-joined coordinate expressions.
 * @throws Error when newDim has more than 6 entries.
 */
function getSwitchedCoords(newDim) {
    var rank = newDim.length;
    if (rank > 6) {
        throw Error("Transpose for rank " + rank + " is not yet supported");
    }
    var outputComponents = ['resRC.x', 'resRC.y', 'resRC.z', 'resRC.w', 'resRC.u', 'resRC.v'];
    var sourceCoords = new Array(rank);
    var axis = 0;
    while (axis < newDim.length) {
        sourceCoords[newDim[axis]] = outputComponents[axis];
        axis++;
    }
    return sourceCoords.join();
}
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
var TransposePackedProgram = /** @class */ (function () {
    /**
     * Shader program description for a packed transpose. Each of the four
     * output lanes is fetched individually with getChannel.
     *
     * @param aShape Shape of the input.
     * @param newDim Permutation: output axis i takes input axis newDim[i].
     * @throws Error when the rank is greater than 6.
     */
    function TransposePackedProgram(aShape, newDim) {
        this.variableNames = ['A'];
        this.packedInputs = true;
        this.packedOutput = true;
        var permutedShape = new Array(aShape.length);
        var axis = 0;
        while (axis < permutedShape.length) {
            permutedShape[axis] = aShape[newDim[axis]];
            axis++;
        }
        this.outputShape = permutedShape;
        var rank = permutedShape.length;
        this.rank = rank;
        if (rank > 6) {
            throw Error("Packed transpose for rank " + rank + " is not yet supported.");
        }
        var coordsType = getCoordsDataType(rank);
        var outputOrder = getVecChannels('rc', rank);
        // Output coordinates rearranged into the input's axis order.
        var sourceOrder = new Array(rank);
        for (var k = 0; k < newDim.length; k++) {
            sourceOrder[newDim[k]] = outputOrder[k];
        }
        var lastCoord = outputOrder[rank - 1];
        var secondLastCoord = outputOrder[rank - 2];
        var innerDims = "vec2(" + sourceOrder.slice(-2).join() + ")";
        var nextColumn = "++" + lastCoord + " < " + permutedShape[rank - 1];
        var getc = "getChannel(getA(" + sourceOrder.join() + "), " + innerDims + ")";
        this.userCode = "\n void main() {\n" +
            " " + coordsType + " rc = getOutputCoords();\n" +
            " vec4 result = vec4(0.);\n" +
            " result[0] = " + getc + ";\n" +
            " if(" + nextColumn + ") {\n" +
            " result[1] = " + getc + ";\n" +
            " }\n" +
            " --" + lastCoord + ";\n" +
            " if(++" + secondLastCoord + " < " + permutedShape[rank - 2] + ") {\n" +
            " result[2] = " + getc + ";\n" +
            " if(" + nextColumn + ") {\n" +
            " result[3] = " + getc + ";\n" +
            " }\n" +
            " }\n" +
            " setOutput(result);\n" +
            " }\n" +
            " ";
    }
    return TransposePackedProgram;
}());
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Transposes `x` on the WebGL backend.
 *
 * @param x Input tensor info.
 * @param perm Axis permutation handed to the transpose program.
 * @param backend Backend used to run the program.
 * @returns Tensor info of the transposed result, with x's dtype.
 */
function transposeImpl$1(x, perm, backend) {
    var program;
    // The environment flag selects the packed or unpacked program.
    if (tf.env().getBool('WEBGL_PACK_ARRAY_OPERATIONS')) {
        program = new TransposePackedProgram(x.shape, perm);
    }
    else {
        program = new TransposeProgram(x.shape, perm);
    }
    return backend.runWebGLProgram(program, [x], x.dtype);
}
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Sums `x` along the given axes on the WebGL backend.
 *
 * When tf.backend_util.getAxesPermutation returns a permutation, the input is
 * transposed first. It is then reshaped to [batchSize, inSize], reduced with
 * the 'sum' reduction, and reshaped to the output shape.
 *
 * @param x Input tensor info.
 * @param axis Axis or axes to reduce, as accepted by tf.util.parseAxisParam.
 * @param keepDims When truthy, the output shape comes from
 *     tf.backend_util.expandShapeToKeepDim.
 * @param backend Backend used to run programs and dispose intermediates.
 * @returns Tensor info of the summed result.
 */
function sumImpl(x, axis, keepDims, backend) {
    var rank = x.shape.length;
    var requestedAxes = tf.util.parseAxisParam(axis, x.shape);
    var permutation = tf.backend_util.getAxesPermutation(requestedAxes, rank);
    var needsTranspose = permutation != null;
    var reduceAxes = requestedAxes;
    var source = x;
    if (needsTranspose) {
        source = transposeImpl$1(x, permutation, backend);
        reduceAxes = tf.backend_util.getInnerMostAxes(reduceAxes.length, rank);
    }
    tf.backend_util.assertAxesAreInnerMostDims('sum', reduceAxes, rank);
    var shapes = tf.backend_util.computeOutAndReduceShapes(source.shape, reduceAxes);
    var reducedShape = shapes[0];
    var reduceShape = shapes[1];
    // rather than reshape at the end, set the target shape here.
    var outShape = keepDims ?
        tf.backend_util.expandShapeToKeepDim(reducedShape, requestedAxes) :
        reducedShape;
    var inSize = tf.util.sizeFromShape(reduceShape);
    var batchSize = tf.util.sizeFromShape(x.shape) / inSize;
    var flattened = reshape({
        inputs: { x: source },
        attrs: { shape: [batchSize, inSize] },
        backend: backend
    });
    var outType = tf.sumOutType(x.dtype);
    var reduced = reduce(flattened, outType, 'sum', backend);
    var out = reshape({
        inputs: { x: reduced },
        attrs: { shape: outShape },
        backend: backend
    });
    // Release the intermediates created along the way.
    backend.disposeIntermediateTensorInfo(flattened);
    backend.disposeIntermediateTensorInfo(reduced);
    if (needsTranspose) {
        backend.disposeIntermediateTensorInfo(source);
    }
    return out;
}
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Sum kernel entry point: unpacks the kernel arguments and delegates to
 * `sumImpl`.
 *
 * @param args Kernel arguments; reads `inputs.x`, `attrs.axis`,
 *     `attrs.keepDims` and `backend`.
 * @returns Whatever `sumImpl` returns.
 */
function sum(args) {
    var input = args.inputs.x;
    var webglBackend = args.backend;
    var options = args.attrs;
    return sumImpl(input, options.axis, options.keepDims, webglBackend);
}
// Kernel config pairing the tf.Sum kernel name with the 'webgl' backend name
// and the function above.
var sumConfig = {
    kernelName: tf.Sum,
    backendName: 'webgl',
    kernelFunc: sum
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Transpose kernel: permutes the dimensions of `inputs.x` according to
 * `attrs.perm`.
 *
 * When the backend reports that the input should be handled on the CPU, the
 * values stored in the backend's texData entry are permuted with
 * `transposeImplCPU` and attached to a freshly made tensor info; otherwise
 * the work is delegated to `transposeImpl$1`.
 *
 * @param args Kernel arguments; reads `inputs.x`, `attrs.perm` and `backend`.
 * @returns Tensor info of the transposed result.
 */
function transpose(args) {
    var webglBackend = args.backend;
    var x = args.inputs.x;
    var perm = args.attrs.perm;
    // Output dimension d takes the size of input dimension perm[d].
    var permutedShape = [];
    for (var d = 0; d < x.shape.length; d++) {
        permutedShape.push(x.shape[perm[d]]);
    }
    if (!webglBackend.shouldExecuteOnCPU([x])) {
        return transposeImpl$1(x, perm, webglBackend);
    }
    var srcValues = webglBackend.texData.get(x.dataId).values;
    var permutedValues = transposeImplCPU(srcValues, x.shape, x.dtype, perm, permutedShape);
    var result = webglBackend.makeTensorInfo(permutedShape, x.dtype);
    // Attach the CPU-computed values directly to the new tensor's data record.
    webglBackend.texData.get(result.dataId).values = permutedValues;
    return result;
}
// Kernel config pairing the tf.Transpose kernel name with the 'webgl' backend
// name and the function above.
var transposeConfig = {
    kernelName: tf.Transpose,
    backendName: 'webgl',
    kernelFunc: transpose
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// Empirically determined minimal shared dimension in matmul before we forward
// to a.mul(b).sum() in order to take advantage of GPU parallelism. See
// https://github.com/tensorflow/tfjs-core/pull/1379 for benchmarks.
var MATMUL_SHARED_DIM_THRESHOLD = 1000;
/**
 * Batched matrix multiplication with optional fused bias / activation.
 *
 * Both operands are reshaped to rank 3 (`[batch, rows, cols]`). Two execution
 * paths exist:
 *  - a "vector" path (`multiply` followed by `sum`) taken when one of the
 *    outer dimensions is 1, the shared dimension exceeds
 *    MATMUL_SHARED_DIM_THRESHOLD and nothing is fused;
 *  - otherwise a `MatMulPackedProgram` run through `backend.runWebGLProgram`.
 * All intermediates created here are disposed before returning.
 *
 * @param _a Options object with fields:
 *     `a`, `b` - operand tensor infos (rank >= 2);
 *     `transposeA`, `transposeB` - whether each operand is transposed;
 *     `backend` - backend used for kernels and disposal;
 *     `bias` (default null), `preluActivationWeights` (default null),
 *     `leakyreluAlpha` (default 0), `activation` (default null) - fused ops.
 * @returns Tensor info of the product, reshaped to the output shape.
 * @throws Via `tf.util.assert` when ranks are below 2, batch sizes are
 *     incompatible, or the inner dimensions do not match.
 */
function batchMatMulImpl(_a) {
    var a = _a.a, b = _a.b, transposeA = _a.transposeA, transposeB = _a.transposeB, backend = _a.backend, _b = _a.bias, bias = _b === void 0 ? null : _b, _c = _a.preluActivationWeights, preluActivationWeights = _c === void 0 ? null : _c, _d = _a.leakyreluAlpha, leakyreluAlpha = _d === void 0 ? 0 : _d, _e = _a.activation, activation = _e === void 0 ? null : _e;
    var aRank = a.shape.length;
    var bRank = b.shape.length;
    var innerShapeA = transposeA ? a.shape[aRank - 2] : a.shape[aRank - 1];
    var innerShapeB = transposeB ? b.shape[bRank - 1] : b.shape[bRank - 2];
    var outerShapeA = transposeA ? a.shape[aRank - 1] : a.shape[aRank - 2];
    var outerShapeB = transposeB ? b.shape[bRank - 2] : b.shape[bRank - 1];
    var outerDimsA = a.shape.slice(0, -2);
    var outerDimsB = b.shape.slice(0, -2);
    var batchDimA = tf.util.sizeFromShape(outerDimsA);
    var batchDimB = tf.util.sizeFromShape(outerDimsB);
    var batchDimsCompatible = batchDimA === batchDimB || batchDimA === 1 || batchDimB === 1;
    tf.util.assert(aRank >= 2 && bRank >= 2 && batchDimsCompatible, function () { return "Error in matMul: the input batch dimensions must either be the " +
        "same or at least one input batch dimension must be 1. Got input " +
        ("batch dimensions of (" + outerDimsA + ") and (" + outerDimsB + ")."); });
    // The operand with the larger batch size dictates the leading output dims.
    var outShapeOuterDims = batchDimA > batchDimB ? a.shape.slice(0, -2) : b.shape.slice(0, -2);
    var outShape = outShapeOuterDims.concat([outerShapeA, outerShapeB]);
    tf.util.assert(innerShapeA === innerShapeB, function () { return "Error in matMul: inner shapes (" + innerShapeA + ") and (" +
        (innerShapeB + ") of Tensors with shapes " + a.shape + " and ") +
        (b.shape + " and transposeA=" + transposeA) +
        (" and transposeB=" + transposeB + " must match."); });
    var a3dShape = transposeA ?
        [batchDimA, innerShapeA, outerShapeA] :
        [batchDimA, outerShapeA, innerShapeA];
    var b3dShape = transposeB ?
        [batchDimB, outerShapeB, innerShapeB] :
        [batchDimB, innerShapeB, outerShapeB];
    // The rest of the implementation is designed to operate on rank-3 tensors
    var a3d = reshape({ inputs: { x: a }, backend: backend, attrs: { shape: a3dShape } });
    var b3d = reshape({ inputs: { x: b }, backend: backend, attrs: { shape: b3dShape } });
    var intermediates = [a3d, b3d];
    var batchDim = Math.max(batchDimA, batchDimB);
    var sharedDim = transposeA ? a3d.shape[1] : a3d.shape[2];
    var hasBias = bias != null;
    var hasPreluActivationWeights = preluActivationWeights != null;
    var hasLeakyreluAlpha = activation === 'leakyrelu';
    var fusedActivation = activation != null ?
        mapActivationToShaderProgram(activation, true) :
        null;
    var containsFusedOps = hasBias || hasPreluActivationWeights ||
        hasLeakyreluAlpha || fusedActivation != null;
    var out;
    // Since the matrices are vectors, it is faster to call mul().sum()
    // because sum() is O(sqrt(N)) due to divide-and-conquer.
    if ((outerShapeA === 1 || outerShapeB === 1) &&
        sharedDim > MATMUL_SHARED_DIM_THRESHOLD && containsFusedOps === false) {
        var aVec = a3d;
        var bVec = b3d;
        if (transposeA) {
            aVec = transpose({ inputs: { x: a3d }, backend: backend, attrs: { perm: [0, 2, 1] } });
            intermediates.push(aVec);
        }
        if (transposeB) {
            bVec = transpose({ inputs: { x: b3d }, backend: backend, attrs: { perm: [0, 2, 1] } });
            intermediates.push(bVec);
        }
        var shouldReshapeA = outerShapeB !== 1;
        var shouldReshapeB = outerShapeB === 1;
        var aVec3d = aVec;
        if (shouldReshapeA) {
            // Reshape with a's OWN batch size: using max(batchDimA, batchDimB)
            // here would request more elements than `aVec` holds whenever a's
            // batch is 1 and b's is larger. multiply() broadcasts the batch.
            aVec3d = reshape({
                inputs: { x: aVec },
                backend: backend,
                attrs: { shape: [batchDimA, sharedDim, 1] }
            });
            intermediates.push(aVec3d);
        }
        var axis = outerShapeB === 1 ? 2 : 1;
        var bVec3d = bVec;
        if (shouldReshapeB) {
            // Same reasoning as above, using b's own batch size.
            bVec3d = reshape({
                inputs: { x: bVec },
                backend: backend,
                attrs: { shape: [batchDimB, 1, sharedDim] }
            });
            intermediates.push(bVec3d);
        }
        var product = multiply({ inputs: { a: aVec3d, b: bVec3d }, backend: backend });
        out = sum({ inputs: { x: product }, backend: backend, attrs: { axis: axis, keepDims: true } });
        intermediates.push(product);
    }
    else {
        var dtype = tf.upcastType(a.dtype, b.dtype);
        var program = new MatMulPackedProgram(a3dShape, b3dShape, [batchDim, outerShapeA, outerShapeB], transposeA, transposeB, hasBias, fusedActivation, hasPreluActivationWeights, hasLeakyreluAlpha);
        var inputs = [a3d, b3d];
        if (bias != null) {
            inputs.push(bias);
        }
        if (hasPreluActivationWeights) {
            inputs.push(preluActivationWeights);
        }
        if (hasLeakyreluAlpha) {
            // The alpha value is uploaded as a scalar tensor created here, so
            // it is tracked as an intermediate and disposed below.
            var $leakyreluAlpha = backend.makeTensorInfo([], 'float32', tf.util.createScalarValue(leakyreluAlpha, 'float32'));
            inputs.push($leakyreluAlpha);
            intermediates.push($leakyreluAlpha);
        }
        out = backend.runWebGLProgram(program, inputs, dtype);
    }
    var outReshaped = reshape({ inputs: { x: out }, backend: backend, attrs: { shape: outShape } });
    intermediates.push(out);
    for (var _i = 0, intermediates_1 = intermediates; _i < intermediates_1.length; _i++) {
        var i = intermediates_1[_i];
        backend.disposeIntermediateTensorInfo(i);
    }
    return outReshaped;
}
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Fused matMul kernel entry point: gathers the operands and fusion options
 * from the kernel arguments and forwards them to `batchMatMulImpl`.
 *
 * @param args Kernel arguments; reads `inputs.{a, b, bias,
 *     preluActivationWeights}`, `attrs.{transposeA, transposeB, activation,
 *     leakyreluAlpha}` and `backend`.
 * @returns Whatever `batchMatMulImpl` returns.
 */
function _fusedMatMul(args) {
    var tensors = args.inputs;
    var webglBackend = args.backend;
    var options = args.attrs;
    var matMulArgs = {
        a: tensors.a,
        b: tensors.b,
        transposeA: options.transposeA,
        transposeB: options.transposeB,
        backend: webglBackend,
        bias: tensors.bias,
        preluActivationWeights: tensors.preluActivationWeights,
        leakyreluAlpha: options.leakyreluAlpha,
        activation: options.activation
    };
    return batchMatMulImpl(matMulArgs);
}
// Kernel config pairing the tf._FusedMatMul kernel name with the 'webgl'
// backend name and the function above.
var _fusedMatMulConfig = {
    kernelName: tf._FusedMatMul,
    backendName: 'webgl',
    kernelFunc: _fusedMatMul
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// Shader snippet computing the element-wise absolute value.
var ABS$1 = "return abs(x);";
/**
 * Abs kernel.
 *
 * Non-complex inputs that the backend wants handled on the CPU are processed
 * with `simpleAbsImplCPU`; everything else runs a unary shader program
 * (packed or unpacked depending on the WEBGL_PACK_UNARY_OPERATIONS flag).
 *
 * @param args Kernel arguments; reads `inputs.x` and `backend`.
 * @returns Tensor info with the same shape and dtype as `x`.
 */
function abs(args) {
    var backend = args.backend;
    var x = args.inputs.x;
    // TODO: handle cases when x is complex. Once the cpu implementation
    // can handle complex values, refactor to use unaryKernelFunc.
    var runOnCPU = backend.shouldExecuteOnCPU([x]) && x.dtype !== 'complex64';
    if (runOnCPU) {
        var cpuValues = simpleAbsImplCPU(backend.texData.get(x.dataId).values);
        return backend.makeTensorInfo(x.shape, x.dtype, cpuValues);
    }
    var usePacked = tf.env().getBool('WEBGL_PACK_UNARY_OPERATIONS');
    var program = usePacked ?
        new UnaryOpPackedProgram(x.shape, ABS$1) :
        new UnaryOpProgram(x.shape, ABS$1);
    return backend.runWebGLProgram(program, [x], x.dtype);
}
// Kernel config pairing the tf.Abs kernel name with the 'webgl' backend name
// and the function above.
var absConfig = {
    kernelName: tf.Abs,
    backendName: 'webgl',
    kernelFunc: abs
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// Shader snippet for acos: prefixed with CHECK_NAN_SNIPPET, returns NAN when
// |x| > 1 and acos(x) otherwise.
var ACOS = CHECK_NAN_SNIPPET +
    "\n if (abs(x) > 1.) {" +
    "\n return NAN;" +
    "\n }" +
    "\n return acos(x);" +
    "\n";
// Kernel function built from the snippet by the shared unary-kernel factory.
var acos = unaryKernelFunc({ opSnippet: ACOS });
// Kernel config pairing the tf.Acos kernel name with the 'webgl' backend name
// and the kernel function above.
var acosConfig = {
    kernelName: tf.Acos,
    backendName: 'webgl',
    kernelFunc: acos
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// Shader snippet for acosh: prefixed with CHECK_NAN_SNIPPET, returns NAN for
// x < 1 and log(x + sqrt(x * x - 1)) otherwise.
var ACOSH = CHECK_NAN_SNIPPET +
    "\n if (x < 1.0) return NAN;" +
    "\nreturn log(x + sqrt(x * x - 1.0));";
// Kernel function built from the snippet by the shared unary-kernel factory.
var acosh = unaryKernelFunc({ opSnippet: ACOSH });
// Kernel config pairing the tf.Acosh kernel name with the 'webgl' backend
// name and the kernel function above.
var acoshConfig = {
    kernelName: tf.Acosh,
    backendName: 'webgl',
    kernelFunc: acosh
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// Shader snippet for element-wise addition; the same text is used for the
// unpacked and the packed program.
var ADD = 'return a + b;';
// Kernel function built by the shared binary-kernel factory. Complex inputs
// are flagged as supported and addImplCPU is supplied as the CPU
// implementation.
var addKernelFunc = binaryKernelFunc({
    opSnippet: ADD,
    packedOpSnippet: ADD,
    supportsComplex: true,
    cpuKernelImpl: addImplCPU
});
// Kernel config pairing the tf.Add kernel name with the 'webgl' backend name
// and the kernel function above.
var addConfig = {
    kernelName: tf.Add,
    backendName: 'webgl',
    kernelFunc: addKernelFunc
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
var AddNProgram = /** @class */ (function () {
    /**
     * Unpacked shader program that adds N input tensors element-wise.
     *
     * Inputs are named T0..T(N-1); the generated shader reads each one at the
     * output coordinates and writes their sum.
     *
     * @param outputShape Shape stored as the program's `outputShape`.
     * @param shapes One entry per input; only the number of entries is used.
     */
    function AddNProgram(outputShape, shapes) {
        this.outputShape = outputShape;
        this.variableNames = shapes.map(function (_shape, index) { return "T" + index; });
        var loads = [];
        var terms = [];
        this.variableNames.forEach(function (name) {
            // One read per input tensor ...
            loads.push("float v" + name + " = get" + name + "AtOutCoords();");
            // ... and one term in the final sum expression.
            terms.push("v" + name);
        });
        this.userCode = "\n void main() {\n " + loads.join('\n ') + "\n\n float result = " + terms.join(' + ') + ";\n setOutput(result);\n }\n ";
    }
    return AddNProgram;
}());
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
var AddNPackedProgram = /** @class */ (function () {
    /**
     * Packed (vec4) shader program that adds N input tensors element-wise.
     *
     * Inputs are named T0..T(N-1); the generated shader reads a vec4 from each
     * one at the output coordinates and writes their sum. Both inputs and
     * output are flagged as packed.
     *
     * @param outputShape Shape stored as the program's `outputShape`.
     * @param shapes One entry per input; only the number of entries is used.
     */
    function AddNPackedProgram(outputShape, shapes) {
        this.outputShape = outputShape;
        this.packedInputs = true;
        this.packedOutput = true;
        this.variableNames = shapes.map(function (_shape, index) { return "T" + index; });
        var loads = [];
        var terms = [];
        this.variableNames.forEach(function (name) {
            // One vec4 read per input tensor ...
            loads.push("vec4 v" + name + " = get" + name + "AtOutCoords();");
            // ... and one term in the final sum expression.
            terms.push("v" + name);
        });
        this.userCode = "\n void main() {\n " + loads.join('\n ') + "\n\n vec4 result = " + terms.join(' + ') + ";\n setOutput(result);\n }\n ";
    }
    return AddNPackedProgram;
}());
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * AddN kernel: element-wise sum of a list of tensors.
 *
 * A single input is forwarded to `identity`. When there are more inputs than
 * the WEBGL_MAX_TEXTURES_IN_SHADER setting, the list is split in half, each
 * half is summed recursively and the two partial sums are added. Otherwise a
 * single shader program (packed or unpacked depending on WEBGL_PACK) adds
 * all inputs at once.
 *
 * @param args Kernel arguments; `inputs` is the array of tensor infos and
 *     `backend` the backend used to run programs and dispose intermediates.
 * @returns Tensor info of the sum, with the upcast dtype of all inputs.
 */
function addN(args) {
    var inputs = args.inputs, backend = args.backend;
    var tensors = inputs;
    if (tensors.length === 1) {
        return identity({ inputs: { x: tensors[0] }, backend: backend });
    }
    // Limit the number of uploaded textures for optimization.
    if (tensors.length > tf.env().get('WEBGL_MAX_TEXTURES_IN_SHADER')) {
        var midIndex = Math.floor(tensors.length / 2);
        var leftSide = addN({ inputs: tensors.slice(0, midIndex), backend: backend });
        var rightSide = addN({ inputs: tensors.slice(midIndex), backend: backend });
        var combined = addN({ inputs: [leftSide, rightSide], backend: backend });
        // The partial sums were created here and are not returned to the
        // caller, so they must be released explicitly or their data leaks.
        backend.disposeIntermediateTensorInfo(leftSide);
        backend.disposeIntermediateTensorInfo(rightSide);
        return combined;
    }
    var dtype = tensors.map(function (t) { return t.dtype; }).reduce(function (d1, d2) { return tf.upcastType(d1, d2); });
    var shapes = tensors.map(function (t) { return t.shape; });
    // We can make sure shapes are identical in op level.
    var usePackedOp = tf.env().getBool('WEBGL_PACK');
    var program = usePackedOp ?
        new AddNPackedProgram(tensors[0].shape, shapes) :
        new AddNProgram(tensors[0].shape, shapes);
    return backend.runWebGLProgram(program, tensors, dtype);
}
// Kernel config pairing the tf.AddN kernel name with the 'webgl' backend name
// and the function above.
var addNConfig = {
    kernelName: tf.AddN,
    backendName: 'webgl',
    kernelFunc: addN
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * All kernel: reduces `inputs.x` over `attrs.axis` with the 'all' op.
 *
 * The reduction axes are moved to the innermost positions (transposing when
 * `getAxesPermutation` reports that this is needed), the data is viewed as a
 * 2D `[-1, inSize]` tensor, reduced, and reshaped to the output shape. All
 * intermediates created here are disposed before returning.
 *
 * @param args Kernel arguments; reads `inputs.x`, `attrs.axis`,
 *     `attrs.keepDims` and `backend`.
 * @returns Tensor info produced by the final reshape.
 */
function all(args) {
    var backend = args.backend;
    var x = args.inputs.x;
    var axis = args.attrs.axis;
    var keepDims = args.attrs.keepDims;
    var rank = x.shape.length;
    var requestedAxes = tf.util.parseAxisParam(axis, x.shape);
    var permutation = tf.backend_util.getAxesPermutation(requestedAxes, rank);
    var wasTransposed = permutation != null;
    var reduceAxes = requestedAxes;
    var source = x;
    if (wasTransposed) {
        source = transpose({ inputs: { x: x }, backend: backend, attrs: { perm: permutation } });
        reduceAxes = tf.backend_util.getInnerMostAxes(requestedAxes.length, rank);
    }
    tf.backend_util.assertAxesAreInnerMostDims('all', reduceAxes, rank);
    var shapes = tf.backend_util.computeOutAndReduceShapes(source.shape, reduceAxes);
    var reducedShape = shapes[0];
    var inSize = tf.util.sizeFromShape(shapes[1]);
    var flattened = reshape({ inputs: { x: source }, backend: backend, attrs: { shape: [-1, inSize] } });
    var collapsed = reduce(flattened, flattened.dtype, 'all', backend);
    // With keepDims the reduced axes are kept in the output shape.
    var finalShape = keepDims ?
        tf.backend_util.expandShapeToKeepDim(reducedShape, requestedAxes) :
        reducedShape;
    var result = reshape({ inputs: { x: collapsed }, backend: backend, attrs: { shape: finalShape } });
    backend.disposeIntermediateTensorInfo(flattened);
    backend.disposeIntermediateTensorInfo(collapsed);
    if (wasTransposed) {
        backend.disposeIntermediateTensorInfo(source);
    }
    return result;
}
// Kernel config pairing the tf.All kernel name with the 'webgl' backend name
// and the function above.
var allConfig = {
    kernelName: tf.All,
    backendName: 'webgl',
    kernelFunc: all
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Any kernel: reduces `inputs.x` over `attrs.axis` with the 'any' op.
 *
 * The reduction axes are moved to the innermost positions (transposing when
 * `getAxesPermutation` reports that this is needed), the data is viewed as a
 * 2D `[-1, inSize]` tensor, reduced, and reshaped to the output shape. All
 * intermediates created here are disposed before returning.
 *
 * @param args Kernel arguments; reads `inputs.x`, `attrs.axis`,
 *     `attrs.keepDims` and `backend`.
 * @returns Tensor info produced by the final reshape.
 */
function any(args) {
    var backend = args.backend;
    var x = args.inputs.x;
    var axis = args.attrs.axis;
    var keepDims = args.attrs.keepDims;
    var rank = x.shape.length;
    var requestedAxes = tf.util.parseAxisParam(axis, x.shape);
    var permutation = tf.backend_util.getAxesPermutation(requestedAxes, rank);
    var wasTransposed = permutation != null;
    var reduceAxes = requestedAxes;
    var source = x;
    if (wasTransposed) {
        source = transpose({ inputs: { x: x }, backend: backend, attrs: { perm: permutation } });
        reduceAxes = tf.backend_util.getInnerMostAxes(requestedAxes.length, rank);
    }
    tf.backend_util.assertAxesAreInnerMostDims('any', reduceAxes, rank);
    var shapes = tf.backend_util.computeOutAndReduceShapes(source.shape, reduceAxes);
    var reducedShape = shapes[0];
    var inSize = tf.util.sizeFromShape(shapes[1]);
    var flattened = reshape({ inputs: { x: source }, backend: backend, attrs: { shape: [-1, inSize] } });
    var collapsed = reduce(flattened, flattened.dtype, 'any', backend);
    // With keepDims the reduced axes are kept in the output shape.
    var finalShape = keepDims ?
        tf.backend_util.expandShapeToKeepDim(reducedShape, requestedAxes) :
        reducedShape;
    var result = reshape({ inputs: { x: collapsed }, backend: backend, attrs: { shape: finalShape } });
    backend.disposeIntermediateTensorInfo(flattened);
    backend.disposeIntermediateTensorInfo(collapsed);
    if (wasTransposed) {
        backend.disposeIntermediateTensorInfo(source);
    }
    return result;
}
// Kernel config pairing the tf.Any kernel name with the 'webgl' backend name
// and the function above.
var anyConfig = {
    kernelName: tf.Any,
    backendName: 'webgl',
    kernelFunc: any
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2017 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
var ArgMinMaxProgram = /** @class */ (function () {
    /**
     * Unpacked shader program for one pass of an argMin/argMax reduction.
     *
     * Each output element scans a window of `windowSize` candidates and writes
     * the index of the best one. On passes after the first, candidate indices
     * are read from a second input, `bestIndicesA`, instead of being computed
     * from the window offset.
     *
     * @param reduceInfo Object providing `windowSize`, `batchSize` and
     *     `outSize`.
     * @param op 'max' selects the '>' comparison; any other value selects '<'.
     * @param firstPass Whether this is the first reduction pass.
     */
    function ArgMinMaxProgram(reduceInfo, op, firstPass) {
        var windowSize = reduceInfo.windowSize;
        this.variableNames = firstPass ? ['A'] : ['A', 'bestIndicesA'];
        this.outputShape = [reduceInfo.batchSize, reduceInfo.outSize];
        var comparison = '<';
        if (op === 'max') {
            comparison = '>';
        }
        var candidateIndex = 'round(getBestIndicesA(batch, inOffset + i));';
        if (firstPass) {
            candidateIndex = 'inOffset + i;';
        }
        this.userCode = "\n void main() {" +
            "\n ivec2 coords = getOutputCoords();" +
            "\n int batch = coords[0];" +
            "\n int outIdx = coords[1];" +
            "\n int inOffset = outIdx * " + windowSize + ";" +
            "\n" +
            "\n int bestIndex = inOffset;" +
            "\n float bestValue = getA(batch, bestIndex);" +
            "\n" +
            "\n for (int i = 0; i < " + windowSize + "; i++) {" +
            "\n int inIdx = " + candidateIndex + ";" +
            "\n float candidate = getA(batch, inIdx);" +
            "\n if (candidate " + comparison + " bestValue) {" +
            "\n bestValue = candidate;" +
            "\n bestIndex = inIdx;" +
            "\n }" +
            "\n }" +
            "\n setOutput(float(bestIndex));" +
            "\n }" +
            "\n ";
    }
    return ArgMinMaxProgram;
}());
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
var ArgMinMaxPackedProgram = /** @class */ (function () {
    /**
     * Packed (vec4) shader program for one pass of an argMin/argMax reduction
     * over the last dimension of `shape`.
     *
     * The output shape is `shape` without its last dimension, plus a trailing
     * `ceil(inSize / windowSize)` dimension when that value exceeds 1. On
     * passes after the first, candidate indices are read from a second input,
     * `bestIndicesA`.
     *
     * @param shape Input shape; asserted to have rank above 2.
     * @param windowSize Number of candidates scanned per output element.
     * @param op 'max' selects `greaterThan`; any other value selects
     *     `lessThan`.
     * @param firstPass Whether this is the first reduction pass.
     */
    function ArgMinMaxPackedProgram(shape, windowSize, op, firstPass) {
        this.variableNames = ['A'];
        this.packedInputs = true;
        this.packedOutput = true;
        tf.util.assert(shape.length > 2, function () {
            var capitalizedOp = op.charAt(0).toUpperCase() + op.slice(1);
            return "Packed arg" + capitalizedOp + " supports only inputs with rank above 2.";
        });
        var inSize = shape[shape.length - 1];
        var outSize = Math.ceil(inSize / windowSize);
        this.outputShape = shape.slice(0, -1);
        if (outSize > 1) {
            this.outputShape.push(outSize);
        }
        if (!firstPass) {
            this.variableNames.push('bestIndicesA');
        }
        var outShape = this.outputShape;
        var rank = outShape.length;
        var dtype = getCoordsDataType(rank);
        var coords = getChannels('coords', rank);
        // When the reduced dimension collapses to a single window, the output
        // has one dimension fewer than the source, so source locations get an
        // extra trailing 0 coordinate.
        var collapsed = outSize === 1;
        var sourceRank = collapsed ? rank + 1 : rank;
        var locType = collapsed ? getCoordsDataType(sourceRank) : dtype;
        var lastCoord = coords[rank - 1];
        var prevCoord = coords[rank - 2];
        // Emits "<type> <name> = <init>;" followed by the coordinate step that
        // moves `coords` to the next texel component.
        var declareLoc = function (name, step) {
            var init = collapsed ? locType + "(" + coords.join() + ", 0)" : "coords";
            return "\n " + locType + " " + name + " = " + init + ";\n " + step + ";";
        };
        var sourceLocSetup = declareLoc('sourceLocR', "++" + lastCoord) +
            declareLoc('sourceLocG', "++" + prevCoord) +
            declareLoc('sourceLocA', "--" + lastCoord) +
            declareLoc('sourceLocB', "--" + prevCoord);
        var channels = ['x', 'y', 'z', 'w', 'u', 'v'].slice(0, sourceRank);
        var inChannel = '.' + channels[sourceRank - 1]; // e.g. ".b" for rank 3.
        var intChannels = channels.map(function (c) { return 'int ' + c; });
        // Argument list: all but the last source coordinate, then the lane of
        // inIdx holding the candidate position.
        var sourceArgs = function (locName, lane) {
            return getChannels(locName, sourceRank - 1).concat(lane).join();
        };
        var rArgs = sourceArgs('sourceLocR', 'inIdx.r');
        var gArgs = sourceArgs('sourceLocG', 'inIdx.g');
        var bArgs = sourceArgs('sourceLocB', 'inIdx.b');
        var aArgs = sourceArgs('sourceLocA', 'inIdx.a');
        var compOp = (op === 'max') ? 'greaterThan' : 'lessThan';
        var fetchCandidateIdx = '';
        var getBestIndicesAChannelSnippet = '';
        if (!firstPass) {
            fetchCandidateIdx = "\n inIdx = round(vec4(getBestIndicesAChannel(" + rArgs + ")," +
                "\n getBestIndicesAChannel(" + gArgs + ")," +
                "\n getBestIndicesAChannel(" + bArgs + ")," +
                "\n getBestIndicesAChannel(" + aArgs + ")));";
            getBestIndicesAChannelSnippet = "\n float getBestIndicesAChannel(" + intChannels.join() + ") {" +
                "\n return getChannel(getBestIndicesA(" + channels.join() + ")," +
                "\n vec2(" + channels.slice(-2).join() + "));" +
                "\n }";
        }
        var fetchValue = "vec4(" +
            "\n getAChannel(" + rArgs + ")," +
            "\n hasNextCol ? getAChannel(" + gArgs + ") : 0.," +
            "\n hasNextRow ? getAChannel(" + bArgs + ") : 0.," +
            "\n hasNextRow && hasNextCol ? getAChannel(" + aArgs + ") : 0.)";
        this.userCode = "\n float getAChannel(" + intChannels.join() + ") {" +
            "\n return getChannel(getA(" + channels.join() + ")," +
            "\n vec2(" + channels.slice(-2).join() + "));" +
            "\n }" +
            "\n " + getBestIndicesAChannelSnippet +
            "\n void main() {" +
            "\n " + dtype + " coords = getOutputCoords();" +
            "\n bool hasNextCol = " + lastCoord + " < " + (outShape[rank - 1] - 1) + ";" +
            "\n bool hasNextRow = " + prevCoord + " < " + (outShape[rank - 2] - 1) + ";" +
            "\n " + sourceLocSetup +
            "\n ivec4 srcIdx = ivec4(sourceLocR" + inChannel + ", sourceLocG" + inChannel + "," +
            "\n sourceLocB" + inChannel + ", sourceLocA" + inChannel + ") * " + windowSize + ";" +
            "\n ivec4 inIdx = srcIdx;" +
            "\n vec4 bestIndex = vec4(inIdx);" +
            "\n vec4 bestValue = " + fetchValue + ";" +
            "\n" +
            "\n for (int i = 0; i < " + windowSize + "; i++) {" +
            "\n inIdx = srcIdx;" +
            "\n " + fetchCandidateIdx +
            "\n vec4 candidate = " + fetchValue + ";" +
            "\n bvec4 nan = isnan(candidate);" +
            "\n bvec4 replace = bvec4(" +
            "\n vec4(" + compOp + "(candidate, bestValue)) * (vec4(1.0) - vec4(nan)));" +
            "\n" +
            "\n bestValue = vec4(replace.x ? candidate.x : bestValue.x," +
            "\n replace.y ? candidate.y : bestValue.y," +
            "\n replace.z ? candidate.z : bestValue.z," +
            "\n replace.w ? candidate.w : bestValue.w);" +
            "\n bestIndex = mix(bestIndex, vec4(inIdx), vec4(replace));" +
            "\n srcIdx++;" +
            "\n }" +
            "\n setOutput(bestIndex);" +
            "\n }" +
            "\n ";
    }
    return ArgMinMaxPackedProgram;
}());
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Runs one ArgMinMaxProgram pass and recurses until the pass output has a
 * single column (`output.shape[1] === 1`).
 *
 * @param backend Object providing `runWebGLProgram` and
 *     `disposeIntermediateTensorInfo`.
 * @param x Input whose `shape[0]` / `shape[1]` are used as batch size and
 *     reduction size on the first pass. It is passed to every pass.
 * @param reduceType Forwarded unchanged to ArgMinMaxProgram.
 * @param bestIndicesA Output of the previous pass; null/omitted on the first
 *     pass. When present, its shape (not `x`'s) sizes the current pass.
 * @returns The 'int32' output of the last pass.
 */
function argReduce(backend, x, reduceType, bestIndicesA) {
    if (bestIndicesA === void 0) { bestIndicesA = null; }
    var isFirstPass = bestIndicesA == null;
    // The previous pass output, when there is one, dictates this pass's sizes.
    var sizedBy = isFirstPass ? x : bestIndicesA;
    var batchSize = sizedBy.shape[0];
    var inSize = sizedBy.shape[1];
    var windowSize = tf.backend_util.computeOptimalWindowSize(inSize);
    var program = new ArgMinMaxProgram({
        windowSize: windowSize,
        inSize: inSize,
        batchSize: batchSize,
        outSize: Math.ceil(inSize / windowSize)
    }, reduceType, isFirstPass);
    var programInputs = isFirstPass ? [x] : [x, bestIndicesA];
    var passOutput = backend.runWebGLProgram(program, programInputs, 'int32');
    if (passOutput.shape[1] !== 1) {
        // Still more than one column: reduce again, then release this pass's
        // output, which was only needed as input to the next pass.
        var finalOutput = argReduce(backend, x, reduceType, passOutput);
        backend.disposeIntermediateTensorInfo(passOutput);
        return finalOutput;
    }
    // No need to run another GPGPU program.
    return passOutput;
}
|
|
/**
 * Runs one ArgMinMaxPackedProgram pass and recurses while the pass output
 * still has as many dimensions as `x`.
 *
 * @param backend Object providing `runWebGLProgram` and
 *     `disposeIntermediateTensorInfo`.
 * @param x Input passed to every pass.
 * @param reduceType Forwarded unchanged to ArgMinMaxPackedProgram.
 * @param bestIndicesA Output of the previous pass; null/omitted on the first
 *     pass. When present, its shape is the pass's input shape.
 * @returns The 'int32' output of the last pass.
 */
function argReducePacked(backend, x, reduceType, bestIndicesA) {
    if (bestIndicesA === void 0) { bestIndicesA = null; }
    var isFirstPass = bestIndicesA == null;
    var passShape = isFirstPass ? x.shape : bestIndicesA.shape;
    // The window is chosen from the size of the last dimension.
    var lastDimSize = passShape[passShape.length - 1];
    var windowSize = tf.backend_util.computeOptimalWindowSize(lastDimSize);
    var program = new ArgMinMaxPackedProgram(passShape, windowSize, reduceType, isFirstPass);
    var programInputs = isFirstPass ? [x] : [x, bestIndicesA];
    var passOutput = backend.runWebGLProgram(program, programInputs, 'int32');
    if (passOutput.shape.length !== x.shape.length) {
        // Rank changed relative to x: this is the final result.
        return passOutput;
    }
    // Same rank as x: run another pass fed with this output, then release it.
    var finalOutput = argReducePacked(backend, x, reduceType, passOutput);
    backend.disposeIntermediateTensorInfo(passOutput);
    return finalOutput;
}
|
|
/**
 * Dispatches an arg-min/arg-max reduction along `axis` to the packed or the
 * unpacked implementation.
 *
 * The packed path is used only when the 'WEBGL_PACK_REDUCE' flag is set and
 * `x` has more than two dimensions. Otherwise `x` is reshaped to
 * `[-1, inSize]`, reduced with argReduce and reshaped to the output shape.
 *
 * @param backend Backend handed to reshape / argReduce / argReducePacked.
 * @param x Input tensor info.
 * @param axis Axis to reduce; asserted to be an inner-most dimension.
 * @param reduceType Reduction name; capitalised and prefixed with 'arg' for
 *     the assertion message, otherwise forwarded unchanged.
 * @returns The tensor info produced by the selected path.
 */
function argMinMaxReduce(backend, x, axis, reduceType) {
    var opName = 'arg' + reduceType.charAt(0).toUpperCase() + reduceType.slice(1);
    var axes = [axis];
    tf.backend_util.assertAxesAreInnerMostDims(opName, axes, x.shape.length);
    var usePacked = tf.env().getBool('WEBGL_PACK_REDUCE') && x.shape.length > 2;
    if (usePacked) {
        return argReducePacked(backend, x, reduceType);
    }
    var shapes = tf.backend_util.computeOutAndReduceShapes(x.shape, axes);
    var outShape = shapes[0];
    var reduceSize = tf.util.sizeFromShape(shapes[1]);
    var flattened = reshape({ inputs: { x: x }, backend: backend, attrs: { shape: [-1, reduceSize] } });
    var indices2D = argReduce(backend, flattened, reduceType);
    var result = reshape({ inputs: { x: indices2D }, backend: backend, attrs: { shape: outShape } });
    // Both temporaries are released only after the final reshape has run.
    backend.disposeIntermediateTensorInfo(flattened);
    backend.disposeIntermediateTensorInfo(indices2D);
    return result;
}
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * WebGL kernel function for ArgMax.
 *
 * Parses `attrs.axis`, transposes `x` when getAxesPermutation returns a
 * permutation, then reduces along the first parsed axis with
 * argMinMaxReduce(..., 'max').
 *
 * @param args Kernel arguments: `inputs.x`, `backend`, `attrs.axis`.
 * @returns The tensor info returned by argMinMaxReduce.
 */
function argMax(args) {
    var backend = args.backend;
    var x = args.inputs.x;
    var axes = tf.util.parseAxisParam(args.attrs.axis, x.shape);
    var perm = tf.backend_util.getAxesPermutation(axes, x.shape.length);
    var needsTranspose = perm != null;
    var reduceInput = x;
    if (needsTranspose) {
        reduceInput = transpose({ inputs: { x: x }, backend: backend, attrs: { perm: perm } });
        axes = tf.backend_util.getInnerMostAxes(axes.length, reduceInput.shape.length);
    }
    tf.backend_util.assertAxesAreInnerMostDims('argMax', [axes[0]], reduceInput.shape.length);
    var out = argMinMaxReduce(backend, reduceInput, axes[0], 'max');
    if (needsTranspose) {
        // The transposed copy was created here, so it is released here.
        backend.disposeIntermediateTensorInfo(reduceInput);
    }
    return out;
}
// Kernel config record pairing tf.ArgMax with the 'webgl' implementation.
var argMaxConfig = {
    kernelName: tf.ArgMax,
    backendName: 'webgl',
    kernelFunc: argMax
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * WebGL kernel function for ArgMin.
 *
 * Parses `attrs.axis`, transposes `x` when getAxesPermutation returns a
 * permutation, then reduces along the first parsed axis with
 * argMinMaxReduce(..., 'min').
 *
 * @param args Kernel arguments: `inputs.x`, `backend`, `attrs.axis`.
 * @returns The tensor info returned by argMinMaxReduce.
 */
function argMin(args) {
    var backend = args.backend;
    var x = args.inputs.x;
    var axes = tf.util.parseAxisParam(args.attrs.axis, x.shape);
    var perm = tf.backend_util.getAxesPermutation(axes, x.shape.length);
    var needsTranspose = perm != null;
    var reduceInput = x;
    if (needsTranspose) {
        reduceInput = transpose({ inputs: { x: x }, backend: backend, attrs: { perm: perm } });
        axes = tf.backend_util.getInnerMostAxes(axes.length, reduceInput.shape.length);
    }
    tf.backend_util.assertAxesAreInnerMostDims('argMin', [axes[0]], reduceInput.shape.length);
    var out = argMinMaxReduce(backend, reduceInput, axes[0], 'min');
    if (needsTranspose) {
        // The transposed copy was created here, so it is released here.
        backend.disposeIntermediateTensorInfo(reduceInput);
    }
    return out;
}
// Kernel config record pairing tf.ArgMin with the 'webgl' implementation.
var argMinConfig = {
    kernelName: tf.ArgMin,
    backendName: 'webgl',
    kernelFunc: argMin
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// GLSL snippet for element-wise asin: CHECK_NAN_SNIPPET (defined elsewhere in
// this file) is prepended, inputs with |x| > 1 yield NAN, everything else is
// passed to the built-in asin().
var ASIN = CHECK_NAN_SNIPPET + [
    '',
    ' if (abs(x) > 1.) {',
    ' return NAN;',
    ' }',
    ' return asin(x);',
    ''
].join('\n');
// Unary kernel built from the snippet above.
var asin = unaryKernelFunc({ opSnippet: ASIN });
// Kernel config record pairing tf.Asin with the 'webgl' implementation.
var asinConfig = {
    kernelName: tf.Asin,
    backendName: 'webgl',
    kernelFunc: asin
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// GLSL snippet for element-wise asinh, appended to CHECK_NAN_SNIPPET (defined
// elsewhere in this file; presumably an early return for NaN inputs - confirm
// against its definition).
//
// asinh is odd, so it is evaluated as sign(x) * log(|x| + sqrt(x^2 + 1)).
// The direct form log(x + sqrt(x * x + 1.0)) subtracts two nearly equal
// numbers for negative x; in 32-bit floats x = -1e4 already gives
// x + sqrt(x * x + 1.0) == 0.0 and therefore log(0.0). Using |x| keeps both
// terms non-negative, and sign(0.0) == 0.0 keeps asinh(0) == 0.
var ASINH = CHECK_NAN_SNIPPET + "return sign(x) * log(abs(x) + sqrt(x * x + 1.0));";
// Unary kernel built from the snippet above.
var asinh = unaryKernelFunc({ opSnippet: ASINH });
// Kernel config record pairing tf.Asinh with the 'webgl' implementation.
var asinhConfig = {
    kernelName: tf.Asinh,
    backendName: 'webgl',
    kernelFunc: asinh,
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// GLSL snippet for element-wise atan: CHECK_NAN_SNIPPET (defined elsewhere in
// this file) followed by a call to the built-in one-argument atan().
var ATAN = CHECK_NAN_SNIPPET + [
    '',
    ' return atan(x);',
    ''
].join('\n');
// Unary kernel built from the snippet above.
var atan = unaryKernelFunc({ opSnippet: ATAN });
// Kernel config record pairing tf.Atan with the 'webgl' implementation.
var atanConfig = {
    kernelName: tf.Atan,
    backendName: 'webgl',
    kernelFunc: atan
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// Unpacked GLSL snippet for atan2: CHECK_NAN_SNIPPET_BINARY (defined elsewhere
// in this file) followed by the built-in two-argument atan(a, b).
var ATAN2 = CHECK_NAN_SNIPPET_BINARY + [
    '',
    ' return atan(a, b);',
    ''
].join('\n');
// Packed (vec4) variant: computes the result, builds an `isNaN` mask that is
// 1.0 wherever either operand is NaN, then splices in
// CHECK_NAN_SNIPPET_BINARY_PACKED (defined elsewhere) before returning.
var ATAN2_PACKED = [
    '',
    ' vec4 result = atan(a, b);',
    ' vec4 isNaN = min(vec4(isnan(a)) + vec4(isnan(b)), vec4(1.0));',
    ' '
].join('\n') + CHECK_NAN_SNIPPET_BINARY_PACKED + [
    '',
    ' return result;',
    ''
].join('\n');
// Binary kernel built from the two snippets above.
var atan2 = binaryKernelFunc({ opSnippet: ATAN2, packedOpSnippet: ATAN2_PACKED });
// Kernel config record pairing tf.Atan2 with the 'webgl' implementation.
var atan2Config = {
    kernelName: tf.Atan2,
    backendName: 'webgl',
    kernelFunc: atan2
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// GLSL snippet for element-wise atanh: CHECK_NAN_SNIPPET (defined elsewhere in
// this file) is prepended, inputs outside [-1, 1] yield NAN, and the rest is
// evaluated as (log(1 + x) - log(1 - x)) / 2.
var ATANH = CHECK_NAN_SNIPPET + [
    '',
    ' if ((x < -1.0) || (x > 1.0)) return NAN;',
    'return (log(1.0 + x) - log(1.0 - x)) / 2.0;'
].join('\n');
// Unary kernel built from the snippet above.
var atanh = unaryKernelFunc({ opSnippet: ATANH });
// Kernel config record pairing tf.Atanh with the 'webgl' implementation.
var atanhConfig = {
    kernelName: tf.Atanh,
    backendName: 'webgl',
    kernelFunc: atanh
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2017 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
var Pool2DProgram = /** @class */ (function () {
    /**
     * Program description for 2D pooling over a single input named `x`.
     *
     * The constructor only assembles data: it sets `variableNames`,
     * `outputShape` (from `convInfo.outShape`) and `userCode` (GLSL source).
     * Two shader variants are generated:
     *  - value mode (`computePositions` falsy): outputs the pooled value,
     *    `avgValue / count` for poolType 'avg', otherwise a reduction of the
     *    running vec4 using `poolType` as the GLSL function name;
     *  - position mode (`computePositions` truthy): outputs the position of
     *    the window element selected with `>=`.
     *
     * @param convInfo Pooling geometry. Fields read here: filterWidth,
     *     strideHeight, strideWidth, dilationHeight, dilationWidth,
     *     effectiveFilterHeight, effectiveFilterWidth, padInfo.top,
     *     padInfo.left, inHeight, inWidth, inChannels, outShape.
     * @param poolType 'avg' selects averaging; any other value is spliced into
     *     the value-mode shader as the final reducing function.
     * @param computePositions When truthy, emit positions instead of values.
     * @param flattenPositions Position mode only. When truthy the position is
     *     a flattened input index; otherwise it is
     *     `wR * effectiveFilterWidth + wC` inside the window. Defaults to false.
     * @param includeBatchInIndex With flattenPositions, also folds `batch`
     *     into the flattened index. Defaults to false.
     * @throws {Error} If poolType is 'avg' and computePositions is truthy.
     */
    function Pool2DProgram(convInfo, poolType, computePositions, flattenPositions, includeBatchInIndex) {
        if (flattenPositions === void 0) { flattenPositions = false; }
        if (includeBatchInIndex === void 0) { includeBatchInIndex = false; }
        this.variableNames = ['x'];
        if (poolType === 'avg' && computePositions) {
            throw new Error('Cannot compute positions for average pool.');
        }
        var filterWidth = convInfo.filterWidth;
        var strideHeight = convInfo.strideHeight;
        var strideWidth = convInfo.strideWidth;
        var dilationHeight = convInfo.dilationHeight;
        var dilationWidth = convInfo.dilationWidth;
        var effectiveFilterHeight = convInfo.effectiveFilterHeight;
        var effectiveFilterWidth = convInfo.effectiveFilterWidth;
        var padTop = convInfo.padInfo.top;
        var padLeft = convInfo.padInfo.left;
        this.outputShape = convInfo.outShape;
        var isAvgPool = poolType === 'avg';
        var initializationValue = '0.0';
        if (!isAvgPool) {
            // WebGL on Firefox Linux can't compile 1/0 so we do 1/eps.
            initializationValue = '-1.0 / 1e-20';
        }
        if (computePositions) {
            var positionCompareOp = '>=';
            // GLSL expression stored as the winning element's position.
            var positionExpr;
            if (!flattenPositions) {
                positionExpr = "wR * " + effectiveFilterWidth + " + wC";
            }
            else if (includeBatchInIndex) {
                positionExpr = "((batch * " + convInfo.inHeight + " + xR) * " + convInfo.inWidth + " + xC) * " + convInfo.inChannels + " + d";
            }
            else {
                positionExpr = "(xR * " + convInfo.inWidth + " + xC) * " + convInfo.inChannels + " + d";
            }
            this.userCode = "\n const ivec2 strides = ivec2(" + strideHeight + ", " + strideWidth + ");" +
                "\n const ivec2 pads = ivec2(" + padTop + ", " + padLeft + ");" +
                "\n" +
                "\n void main() {" +
                "\n ivec4 coords = getOutputCoords();" +
                "\n int batch = coords[0];" +
                "\n int d = coords[3];" +
                "\n" +
                "\n ivec2 xRCCorner = coords.yz * strides - pads;" +
                "\n int xRCorner = xRCCorner.x;" +
                "\n int xCCorner = xRCCorner.y;" +
                "\n" +
                "\n // max/min x(?, ?, d) to get y(yR, yC, d)." +
                "\n // ? = to be determined" +
                "\n float minMaxValue = 0.0;" +
                "\n float minMaxValueFound = 0.0;" +
                "\n int minMaxPosition = 0;" +
                "\n float avgValue = 0.0;" +
                "\n" +
                "\n for (int wR = 0; wR < " + effectiveFilterHeight + ";" +
                "\n wR += " + dilationHeight + ") {" +
                "\n int xR = xRCorner + wR;" +
                "\n" +
                "\n if (xR < 0 || xR >= " + convInfo.inHeight + ") {" +
                "\n continue;" +
                "\n }" +
                "\n" +
                "\n for (int wC = 0; wC < " + effectiveFilterWidth + ";" +
                "\n wC += " + dilationWidth + ") {" +
                "\n int xC = xCCorner + wC;" +
                "\n" +
                "\n if (xC < 0 || xC >= " + convInfo.inWidth + ") {" +
                "\n continue;" +
                "\n }" +
                "\n" +
                "\n float value = getX(batch, xR, xC, d);" +
                "\n" +
                "\n // If a min / max value has already been found, use it. If not," +
                "\n // use the current value." +
                "\n float currMinMaxValue = mix(" +
                "\n value, minMaxValue, minMaxValueFound);" +
                "\n if (value " + positionCompareOp + " currMinMaxValue) {" +
                "\n minMaxValue = value;" +
                "\n minMaxValueFound = 1.0;" +
                "\n minMaxPosition = " + positionExpr + ";" +
                "\n }" +
                "\n }" +
                "\n }" +
                "\n setOutput(float(minMaxPosition));" +
                "\n }" +
                "\n ";
            return;
        }
        var compareOp = 'max';
        var returnValue = poolType + "(" + poolType + "(" + poolType + "(" +
            'minMaxValue[0], minMaxValue[1]), minMaxValue[2]), minMaxValue[3])';
        if (poolType === 'avg') {
            returnValue = "avgValue / count";
        }
        // Each window row is consumed four taps at a time; the 1-3 leftover
        // taps are handled by the if / else-if chain after the vec4 loop.
        var filterWidthNearestVec4 = Math.floor(filterWidth / 4) * 4;
        var filterWidthVec4Remainder = filterWidth % 4;
        // Tap k of a row sits at column xCCorner + k * dilationWidth (see the
        // vec4 loop), so the first leftover tap must be scaled by the dilation
        // as well. Computed here so the emitted text is unchanged for
        // dilationWidth === 1.
        var remainderStart = filterWidthNearestVec4 * dilationWidth;
        var updateSnippet = "\n if (" + isAvgPool + ") {" +
            "\n avgValue += dot(values, ones);" +
            "\n } else {" +
            "\n minMaxValue = " + compareOp + "(values, minMaxValue);" +
            "\n }" +
            "\n ";
        this.userCode = "\n const ivec2 strides = ivec2(" + strideHeight + ", " + strideWidth + ");" +
            "\n const ivec2 pads = ivec2(" + padTop + ", " + padLeft + ");" +
            "\n const float initializationValue = " + initializationValue + ";" +
            "\n const vec4 ones = vec4(1.0, 1.0, 1.0, 1.0);" +
            "\n" +
            "\n float count = 0.0;" +
            "\n" +
            "\n float getValue(int batch, int xR, int xC, int d) {" +
            "\n if (xC < 0 || xC >= " + convInfo.inWidth + ") {" +
            "\n return initializationValue;" +
            "\n }" +
            "\n count += 1.0;" +
            "\n return getX(batch, xR, xC, d);" +
            "\n }" +
            "\n" +
            "\n void main() {" +
            "\n ivec4 coords = getOutputCoords();" +
            "\n int batch = coords[0];" +
            "\n int d = coords[3];" +
            "\n" +
            "\n ivec2 xRCCorner = coords.yz * strides - pads;" +
            "\n int xRCorner = xRCCorner.x;" +
            "\n int xCCorner = xRCCorner.y;" +
            "\n" +
            "\n // max/min x(?, ?, d) to get y(yR, yC, d)." +
            "\n // ? = to be determined" +
            "\n vec4 minMaxValue = vec4(" + initializationValue + ");" +
            "\n float avgValue = 0.0;" +
            "\n count = 0.0;" +
            "\n" +
            "\n for (int wR = 0; wR < " + effectiveFilterHeight + ";" +
            "\n wR += " + dilationHeight + ") {" +
            "\n int xR = xRCorner + wR;" +
            "\n" +
            "\n if (xR < 0 || xR >= " + convInfo.inHeight + ") {" +
            "\n continue;" +
            "\n }" +
            "\n" +
            "\n for (int wC = 0; wC < " + filterWidthNearestVec4 + "; wC += 4) {" +
            "\n int xC = xCCorner + wC * " + dilationWidth + ";" +
            "\n" +
            "\n vec4 values = vec4(" +
            "\n getValue(batch, xR, xC, d)," +
            "\n getValue(batch, xR, xC + " + dilationWidth + ", d)," +
            "\n getValue(batch, xR, xC + 2 * " + dilationWidth + ", d)," +
            "\n getValue(batch, xR, xC + 3 * " + dilationWidth + ", d)" +
            "\n );" +
            "\n" +
            "\n " + updateSnippet +
            "\n }" +
            "\n" +
            "\n int xC = xCCorner + " + remainderStart + ";" +
            "\n if (" + (filterWidthVec4Remainder === 1) + ") {" +
            "\n vec4 values = vec4(" +
            "\n getValue(batch, xR, xC, d)," +
            "\n initializationValue," +
            "\n initializationValue," +
            "\n initializationValue" +
            "\n );" +
            "\n" +
            "\n " + updateSnippet +
            "\n } else if (" + (filterWidthVec4Remainder === 2) + ") {" +
            "\n vec4 values = vec4(" +
            "\n getValue(batch, xR, xC, d)," +
            "\n getValue(batch, xR, xC + " + dilationWidth + ", d)," +
            "\n initializationValue," +
            "\n initializationValue" +
            "\n );" +
            "\n" +
            "\n " + updateSnippet +
            "\n } else if (" + (filterWidthVec4Remainder === 3) + ") {" +
            "\n vec4 values = vec4(" +
            "\n getValue(batch, xR, xC, d)," +
            "\n getValue(batch, xR, xC + " + dilationWidth + ", d)," +
            "\n getValue(batch, xR, xC + 2 * " + dilationWidth + ", d)," +
            "\n initializationValue" +
            "\n );" +
            "\n" +
            "\n " + updateSnippet +
            "\n }" +
            "\n }" +
            "\n setOutput(" + returnValue + ");" +
            "\n }" +
            "\n ";
    }
    return Pool2DProgram;
}());
|
|
var Pool3DProgram = /** @class */ (function () {
    /**
     * Program description for 3D pooling over a single input named `x`.
     *
     * The constructor only assembles data: it sets `variableNames`,
     * `outputShape` (from `convInfo.outShape`) and `userCode` (GLSL source).
     * Two shader variants are generated:
     *  - value mode (`computePositions` falsy): outputs the pooled value,
     *    `avgValue / count` for poolType 'avg', otherwise a reduction of the
     *    running vec4 using `poolType` as the GLSL function name;
     *  - position mode (`computePositions` truthy): outputs the position of
     *    the window element selected with `>=`.
     *
     * @param convInfo Pooling geometry. Fields read here: filterWidth,
     *     strideDepth/Height/Width, dilationDepth/Height/Width,
     *     effectiveFilterDepth/Height/Width, padInfo.front/top/left,
     *     inDepth, inHeight, inWidth, inChannels, outShape.
     * @param poolType 'avg' selects averaging; any other value is spliced into
     *     the value-mode shader as the final reducing function.
     * @param computePositions When truthy, emit positions instead of values.
     * @param flattenPositions Position mode only. When truthy the position is
     *     a flattened input index; otherwise it is the offset inside the
     *     effective window. Defaults to false.
     * @param includeBatchInIndex With flattenPositions, also folds `batch`
     *     into the flattened index. Defaults to false.
     * @throws {Error} If poolType is 'avg' and computePositions is truthy.
     */
    function Pool3DProgram(convInfo, poolType, computePositions, flattenPositions, includeBatchInIndex) {
        if (flattenPositions === void 0) { flattenPositions = false; }
        if (includeBatchInIndex === void 0) { includeBatchInIndex = false; }
        this.variableNames = ['x'];
        if (poolType === 'avg' && computePositions) {
            throw new Error('Cannot compute positions for average pool.');
        }
        var filterWidth = convInfo.filterWidth;
        var strideDepth = convInfo.strideDepth;
        var strideHeight = convInfo.strideHeight;
        var strideWidth = convInfo.strideWidth;
        var dilationDepth = convInfo.dilationDepth;
        var dilationHeight = convInfo.dilationHeight;
        var dilationWidth = convInfo.dilationWidth;
        var effectiveFilterDepth = convInfo.effectiveFilterDepth;
        var effectiveFilterHeight = convInfo.effectiveFilterHeight;
        var effectiveFilterWidth = convInfo.effectiveFilterWidth;
        var padFront = convInfo.padInfo.front;
        var padTop = convInfo.padInfo.top;
        var padLeft = convInfo.padInfo.left;
        this.outputShape = convInfo.outShape;
        var isAvgPool = poolType === 'avg';
        var initializationValue = '0.0';
        if (!isAvgPool) {
            // WebGL on Firefox Linux can't compile 1/0 so we do 1/eps.
            initializationValue = '-1.0 / 1e-20';
        }
        if (computePositions) {
            var positionCompareOp = '>=';
            // GLSL expression stored as the winning element's position.
            var positionExpr;
            if (!flattenPositions) {
                positionExpr = "wD * " + effectiveFilterHeight + " * " + effectiveFilterWidth + " +" +
                    "\n wR * " + effectiveFilterWidth + " + wC";
            }
            else if (includeBatchInIndex) {
                positionExpr = "(((batch * " + convInfo.inDepth + " + xD) * " + convInfo.inHeight + " + xR) * " + convInfo.inWidth + " + xC) * " + convInfo.inChannels + " + ch";
            }
            else {
                positionExpr = "((xD * " + convInfo.inHeight + " + xR) * " + convInfo.inWidth + " + xC) * " + convInfo.inChannels + " + ch";
            }
            this.userCode = "\n const ivec3 strides =" +
                "\n ivec3(" + strideDepth + ", " + strideHeight + ", " + strideWidth + ");" +
                "\n const ivec3 pads = ivec3(" + padFront + ", " + padTop + ", " + padLeft + ");" +
                "\n" +
                "\n void main() {" +
                "\n ivec5 coords = getOutputCoords();" +
                "\n int batch = coords.x;" +
                "\n int ch = coords.u;" +
                "\n" +
                "\n ivec3 xCorner = ivec3(coords.y, coords.z, coords.w) * strides - pads;" +
                "\n int xDCorner = xCorner.x;" +
                "\n int xRCorner = xCorner.y;" +
                "\n int xCCorner = xCorner.z;" +
                "\n" +
                "\n // max/min x(?, ?, ?, ch) to get y(yD, yR, yC, ch)." +
                "\n // ? = to be determined" +
                "\n float minMaxValue = 0.0;" +
                "\n float minMaxValueFound = 0.0;" +
                "\n int minMaxPosition = 0;" +
                "\n" +
                "\n for (int wD = 0; wD < " + effectiveFilterDepth + ";" +
                "\n wD += " + dilationDepth + ") {" +
                "\n int xD = xDCorner + wD;" +
                "\n" +
                "\n if (xD < 0 || xD >= " + convInfo.inDepth + ") {" +
                "\n continue;" +
                "\n }" +
                "\n" +
                "\n for (int wR = 0; wR < " + effectiveFilterHeight + ";" +
                "\n wR += " + dilationHeight + ") {" +
                "\n int xR = xRCorner + wR;" +
                "\n" +
                "\n if (xR < 0 || xR >= " + convInfo.inHeight + ") {" +
                "\n continue;" +
                "\n }" +
                "\n" +
                "\n for (int wC = 0; wC < " + effectiveFilterWidth + ";" +
                "\n wC += " + dilationWidth + ") {" +
                "\n int xC = xCCorner + wC;" +
                "\n" +
                "\n if (xC < 0 || xC >= " + convInfo.inWidth + ") {" +
                "\n continue;" +
                "\n }" +
                "\n" +
                "\n float value = getX(batch, xD, xR, xC, ch);" +
                "\n" +
                "\n // If a min / max value has already been found, use it. If not," +
                "\n // use the current value." +
                "\n float currMinMaxValue = mix(" +
                "\n value, minMaxValue, minMaxValueFound);" +
                "\n if (value " + positionCompareOp + " currMinMaxValue) {" +
                "\n minMaxValue = value;" +
                "\n minMaxValueFound = 1.0;" +
                "\n minMaxPosition = " + positionExpr + ";" +
                "\n }" +
                "\n }" +
                "\n }" +
                "\n }" +
                "\n setOutput(float(minMaxPosition));" +
                "\n }" +
                "\n ";
            return;
        }
        var compareOp = 'max';
        var returnValue = poolType + "(" + poolType + "(" + poolType + "(" +
            'minMaxValue[0], minMaxValue[1]), minMaxValue[2]), minMaxValue[3])';
        if (poolType === 'avg') {
            returnValue = "avgValue / count";
        }
        // Each window row is consumed four taps at a time; the 1-3 leftover
        // taps are handled by the if / else-if chain after the vec4 loop.
        var filterWidthNearestVec4 = Math.floor(filterWidth / 4) * 4;
        var filterWidthVec4Remainder = filterWidth % 4;
        // Tap k of a row sits at column xCCorner + k * dilationWidth (see the
        // vec4 loop), so the first leftover tap must be scaled by the dilation
        // as well. Computed here so the emitted text is unchanged for
        // dilationWidth === 1.
        var remainderStart = filterWidthNearestVec4 * dilationWidth;
        var updateSnippet = "\n if (" + isAvgPool + ") {" +
            "\n avgValue += dot(values, ones);" +
            "\n } else {" +
            "\n minMaxValue = " + compareOp + "(values, minMaxValue);" +
            "\n }" +
            "\n ";
        // NOTE: setOutput(...) is emitted inside the wD loop, exactly as in the
        // existing shader; that placement is intentionally left untouched.
        this.userCode = "\n const ivec3 strides =" +
            "\n ivec3(" + strideDepth + ", " + strideHeight + ", " + strideWidth + ");" +
            "\n const ivec3 pads = ivec3(" + padFront + ", " + padTop + ", " + padLeft + ");" +
            "\n const float initializationValue = " + initializationValue + ";" +
            "\n const vec4 ones = vec4(1.0, 1.0, 1.0, 1.0);" +
            "\n" +
            "\n float count = 0.0;" +
            "\n" +
            "\n float getValue(int batch, int xD, int xR, int xC, int ch) {" +
            "\n if (xC < 0 || xC >= " + convInfo.inWidth + ") {" +
            "\n return initializationValue;" +
            "\n }" +
            "\n count += 1.0;" +
            "\n return getX(batch, xD, xR, xC, ch);" +
            "\n }" +
            "\n" +
            "\n void main() {" +
            "\n ivec5 coords = getOutputCoords();" +
            "\n int batch = coords.x;" +
            "\n int ch = coords.u;" +
            "\n" +
            "\n ivec3 xCorner = ivec3(coords.y, coords.z, coords.w) * strides - pads;" +
            "\n int xDCorner = xCorner.x;" +
            "\n int xRCorner = xCorner.y;" +
            "\n int xCCorner = xCorner.z;" +
            "\n" +
            "\n // max/min x(?, ?, ?, d) to get y(yD, yR, yC, ch)." +
            "\n // ? = to be determined" +
            "\n vec4 minMaxValue = vec4(" + initializationValue + ");" +
            "\n float avgValue = 0.0;" +
            "\n count = 0.0;" +
            "\n" +
            "\n for (int wD = 0; wD < " + effectiveFilterDepth + ";" +
            "\n wD += " + dilationDepth + ") {" +
            "\n int xD = xDCorner + wD;" +
            "\n" +
            "\n if (xD < 0 || xD >= " + convInfo.inDepth + ") {" +
            "\n continue;" +
            "\n }" +
            "\n" +
            "\n for (int wR = 0; wR < " + effectiveFilterHeight + ";" +
            "\n wR += " + dilationHeight + ") {" +
            "\n int xR = xRCorner + wR;" +
            "\n" +
            "\n if (xR < 0 || xR >= " + convInfo.inHeight + ") {" +
            "\n continue;" +
            "\n }" +
            "\n" +
            "\n for (int wC = 0; wC < " + filterWidthNearestVec4 + "; wC += 4) {" +
            "\n int xC = xCCorner + wC * " + dilationWidth + ";" +
            "\n" +
            "\n vec4 values = vec4(" +
            "\n getValue(batch, xD, xR, xC, ch)," +
            "\n getValue(batch, xD, xR, xC + " + dilationWidth + ", ch)," +
            "\n getValue(batch, xD, xR, xC + 2 * " + dilationWidth + ", ch)," +
            "\n getValue(batch, xD, xR, xC + 3 * " + dilationWidth + ", ch)" +
            "\n );" +
            "\n" +
            "\n " + updateSnippet +
            "\n }" +
            "\n" +
            "\n int xC = xCCorner + " + remainderStart + ";" +
            "\n if (" + (filterWidthVec4Remainder === 1) + ") {" +
            "\n vec4 values = vec4(" +
            "\n getValue(batch, xD, xR, xC, ch)," +
            "\n initializationValue," +
            "\n initializationValue," +
            "\n initializationValue" +
            "\n );" +
            "\n" +
            "\n " + updateSnippet +
            "\n } else if (" + (filterWidthVec4Remainder === 2) + ") {" +
            "\n vec4 values = vec4(" +
            "\n getValue(batch, xD, xR, xC, ch)," +
            "\n getValue(batch, xD, xR, xC + " + dilationWidth + ", ch)," +
            "\n initializationValue," +
            "\n initializationValue" +
            "\n );" +
            "\n" +
            "\n " + updateSnippet +
            "\n } else if (" + (filterWidthVec4Remainder === 3) + ") {" +
            "\n vec4 values = vec4(" +
            "\n getValue(batch, xD, xR, xC, ch)," +
            "\n getValue(batch, xD, xR, xC + " + dilationWidth + ", ch)," +
            "\n getValue(batch, xD, xR, xC + 2 * " + dilationWidth + ", ch)," +
            "\n initializationValue" +
            "\n );" +
            "\n" +
            "\n " + updateSnippet +
            "\n }" +
            "\n }" +
            "\n setOutput(" + returnValue + ");" +
            "\n }" +
            "\n }" +
            "\n ";
    }
    return Pool3DProgram;
}());
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * WebGL kernel function for AvgPool.
 *
 * Rejects complex inputs through assertNotComplex, asserts that either the
 * strides or the (fixed, 1) dilations are one, and computes the pooling
 * geometry. A 1x1 filter whose input and output shapes are equal is served by
 * `identity`; every other case runs a Pool2DProgram in 'avg' mode.
 *
 * @param args Kernel arguments: `inputs.x`, `backend`, and `attrs` with
 *     filterSize, strides, pad and dimRoundingMode.
 * @returns The result of `identity` or the 'float32' program output.
 */
function avgPool(args) {
    var backend = args.backend;
    var x = args.inputs.x;
    var attrs = args.attrs;
    assertNotComplex(x, 'avgPool');
    var filterSize = attrs.filterSize;
    var strides = attrs.strides;
    var pad = attrs.pad;
    var dimRoundingMode = attrs.dimRoundingMode;
    var dilations = 1;
    var stridesOrDilationsOk = tf.backend_util.eitherStridesOrDilationsAreOne(strides, dilations);
    tf.util.assert(stridesOrDilationsOk, function () {
        // Built lazily; tf.util.assert receives it as a callback.
        return 'Error in avgPool: Either strides or dilations must be 1. ' +
            ("Got strides " + strides + " and dilations '" + dilations + "'");
    });
    var convInfo = tf.backend_util.computePool2DInfo(x.shape, filterSize, strides, dilations, pad, dimRoundingMode);
    var isUnitFilter = convInfo.filterWidth === 1 && convInfo.filterHeight === 1;
    if (isUnitFilter && tf.util.arraysEqual(convInfo.inShape, convInfo.outShape)) {
        return identity({ inputs: { x: x }, backend: backend });
    }
    return backend.runWebGLProgram(new Pool2DProgram(convInfo, 'avg', false), [x], 'float32');
}
// Kernel config record pairing tf.AvgPool with the 'webgl' implementation.
var avgPoolConfig = {
    kernelName: tf.AvgPool,
    backendName: 'webgl',
    kernelFunc: avgPool
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * WebGL kernel function for AvgPool3D.
 *
 * Computes the 3D pooling geometry with dilations fixed to [1, 1, 1] and runs
 * a Pool3DProgram in 'avg' mode on `x`.
 *
 * @param args Kernel arguments: `inputs.x`, `backend`, and `attrs` with
 *     filterSize, strides, pad, dimRoundingMode and dataFormat.
 * @returns The 'float32' program output.
 */
function avgPool3D(args) {
    var backend = args.backend;
    var x = args.inputs.x;
    var attrs = args.attrs;
    var unitDilations = [1, 1, 1];
    var convInfo = tf.backend_util.computePool3DInfo(x.shape, attrs.filterSize, attrs.strides, unitDilations, attrs.pad, attrs.dimRoundingMode, attrs.dataFormat);
    return backend.runWebGLProgram(new Pool3DProgram(convInfo, 'avg', false), [x], 'float32');
}
// Kernel config record pairing tf.AvgPool3D with the 'webgl' implementation.
var avgPool3DConfig = {
    kernelName: tf.AvgPool3D,
    backendName: 'webgl',
    kernelFunc: avgPool3D
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2017 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
var AvgPool2DBackpropProgram = /** @class */ (function () {
    /**
     * Program description for the gradient of 2D average pooling with respect
     * to its input. Reads a single input named `dy` and writes an output
     * shaped like `convInfo.inShape`.
     *
     * For each output element the shader walks the effective filter window,
     * keeps only the taps that map to an integral, in-range `dy` coordinate
     * after dividing by the stride, and accumulates `dy * avgMultiplier`,
     * where avgMultiplier is 1 / (filterHeight * filterWidth).
     *
     * @param convInfo Pooling geometry. Fields read here: inShape,
     *     filterHeight, filterWidth, strideHeight, strideWidth,
     *     dilationHeight, dilationWidth, effectiveFilterHeight,
     *     effectiveFilterWidth, padInfo.top, padInfo.left, outHeight, outWidth.
     */
    function AvgPool2DBackpropProgram(convInfo) {
        this.variableNames = ['dy'];
        this.outputShape = convInfo.inShape;
        var effHeight = convInfo.effectiveFilterHeight;
        var effWidth = convInfo.effectiveFilterWidth;
        // Padding as seen from the gradient side of the window.
        var rowPad = effHeight - 1 - convInfo.padInfo.top;
        var colPad = effWidth - 1 - convInfo.padInfo.left;
        var scale = 1 / (convInfo.filterHeight * convInfo.filterWidth);
        var shaderLines = [
            '',
            ' const ivec2 pads = ivec2(' + rowPad + ', ' + colPad + ');',
            ' const float avgMultiplier = float(' + scale + ');',
            '',
            ' void main() {',
            ' ivec4 coords = getOutputCoords();',
            ' int b = coords[0];',
            ' int d = coords[3];',
            '',
            ' ivec2 dyRCCorner = coords.yz - pads;',
            ' int dyRCorner = dyRCCorner.x;',
            ' int dyCCorner = dyRCCorner.y;',
            '',
            ' // Convolve dy(?, ?, d) with pos mask(:, :, d) to get dx(xR, xC, d).',
            ' // ? = to be determined. : = across all values in that axis.',
            ' float dotProd = 0.0;',
            ' for (int wR = 0; wR < ' + effHeight + ';',
            ' wR += ' + convInfo.dilationHeight + ') {',
            ' float dyR = float(dyRCorner + wR) / ' + convInfo.strideHeight + '.0;',
            '',
            ' if (dyR < 0.0 || dyR >= ' + convInfo.outHeight + '.0 || fract(dyR) > 0.0) {',
            ' continue;',
            ' }',
            ' int idyR = int(dyR);',
            '',
            ' for (int wC = 0; wC < ' + effWidth + ';',
            ' wC+= ' + convInfo.dilationWidth + ') {',
            ' float dyC = float(dyCCorner + wC) / ' + convInfo.strideWidth + '.0;',
            '',
            ' if (dyC < 0.0 || dyC >= ' + convInfo.outWidth + '.0 ||',
            ' fract(dyC) > 0.0) {',
            ' continue;',
            ' }',
            ' int idyC = int(dyC);',
            '',
            ' float dyValue = getDy(b, idyR, idyC, d);',
            '',
            ' dotProd += dyValue * avgMultiplier;',
            ' }',
            ' }',
            ' setOutput(dotProd);',
            ' }',
            ' '
        ];
        this.userCode = shaderLines.join('\n');
    }
    return AvgPool2DBackpropProgram;
}());
|
|
/**
 * Shader program description that computes the input gradient of a 3D
 * average pool from the upstream gradient `dy`.
 *
 * The generated GLSL iterates depth, row and column filter taps, skips taps
 * that fall outside `dy` or between strides, and accumulates the remaining
 * `dy` values scaled by 1 / (filterDepth * filterHeight * filterWidth).
 */
var AvgPool3DBackpropProgram = /** @class */ (function () {
    /**
     * @param convInfo Pooling description. The fields read here are inShape,
     *     the filter/stride/dilation/effectiveFilter Depth, Height and Width
     *     values, padInfo.front/top/left and outDepth/outHeight/outWidth.
     */
    function AvgPool3DBackpropProgram(convInfo) {
        this.variableNames = ['dy'];
        this.outputShape = convInfo.inShape;
        var effDepth = convInfo.effectiveFilterDepth;
        var effHeight = convInfo.effectiveFilterHeight;
        var effWidth = convInfo.effectiveFilterWidth;
        // Offsets subtracted from an output coordinate to find the first
        // candidate dy position along each pooled axis.
        var padFront = effDepth - 1 - convInfo.padInfo.front;
        var padTop = effHeight - 1 - convInfo.padInfo.top;
        var padLeft = effWidth - 1 - convInfo.padInfo.left;
        var avgMultiplier = 1 / (convInfo.filterDepth * convInfo.filterHeight * convInfo.filterWidth);
        // Constants plus the prologue of main().
        var prologue = "\n const ivec3 pads = ivec3(" + padFront + ", " + padTop + ", " + padLeft + ");\n const float avgMultiplier = float(" + avgMultiplier + ");\n\n void main() {\n ivec5 coords = getOutputCoords();\n int batch = coords.x;\n int ch = coords.u;\n\n ivec3 dyCorner = ivec3(coords.y, coords.z, coords.w) - pads;\n int dyDCorner = dyCorner.x;\n int dyRCorner = dyCorner.y;\n int dyCCorner = dyCorner.z;\n\n // Convolve dy(?, ?, ?, d) with pos mask(:, :, :, ch) to get\n // dx(xD, xR, xC, ch).\n // ? = to be determined. : = across all values in that axis.\n float dotProd = 0.0;\n\n ";
        // Loop over filter depth.
        var depthLoop = "for (int wD = 0; wD < " + effDepth + ";\n wD += " + convInfo.dilationDepth + ") {\n float dyD = float(dyDCorner + wD) / " + convInfo.strideDepth + ".0;\n\n if (dyD < 0.0 || dyD >= " + convInfo.outDepth + ".0 || fract(dyD) > 0.0) {\n continue;\n }\n int idyD = int(dyD);\n\n ";
        // Loop over filter rows.
        var rowLoop = "for (int wR = 0; wR < " + effHeight + ";\n wR += " + convInfo.dilationHeight + ") {\n float dyR = float(dyRCorner + wR) / " + convInfo.strideHeight + ".0;\n\n if (dyR < 0.0 || dyR >= " + convInfo.outHeight + ".0 ||\n fract(dyR) > 0.0) {\n continue;\n }\n int idyR = int(dyR);\n\n ";
        // Loop over filter columns, accumulation and epilogue.
        var colLoop = "for (int wC = 0; wC < " + effWidth + ";\n wC += " + convInfo.dilationWidth + ") {\n float dyC = float(dyCCorner + wC) / " + convInfo.strideWidth + ".0;\n\n if (dyC < 0.0 || dyC >= " + convInfo.outWidth + ".0 ||\n fract(dyC) > 0.0) {\n continue;\n }\n int idyC = int(dyC);\n\n float dyValue = getDy(batch, idyD, idyR, idyC, ch);\n\n dotProd += dyValue * avgMultiplier;\n }\n }\n }\n setOutput(dotProd);\n }\n ";
        this.userCode = prologue + depthLoop + rowLoop + colLoop;
    }
    return AvgPool3DBackpropProgram;
}());
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * WebGL kernel function for the gradient of a 3D average pool.
 *
 * @param args Kernel arguments: `inputs.dy` (upstream gradient),
 *     `inputs.input` (the forward input), `backend`, and `attrs` carrying
 *     filterSize, strides, pad and dimRoundingMode.
 * @returns The result of running AvgPool3DBackpropProgram on `dy`, using the
 *     dtype of the forward input.
 */
function avgPool3DGrad(args) {
    var backend = args.backend;
    var attrs = args.attrs;
    var dy = args.inputs.dy;
    var x = args.inputs.input;
    var filterSize = attrs.filterSize;
    var strides = attrs.strides;
    var pad = attrs.pad;
    var dimRoundingMode = attrs.dimRoundingMode;
    // The pooling info is always computed with unit dilation on every axis.
    var unitDilations = [1, 1, 1];
    var poolInfo = tf.backend_util.computePool3DInfo(x.shape, filterSize, strides, unitDilations, pad, dimRoundingMode);
    var backpropProgram = new AvgPool3DBackpropProgram(poolInfo);
    return backend.runWebGLProgram(backpropProgram, [dy], x.dtype);
}
|
|
// Kernel config pairing the AvgPool3DGrad kernel name with the WebGL
// `avgPool3DGrad` kernel function.
var avgPoolGrad3DConfig = {
    kernelName: tf.AvgPool3DGrad,
    backendName: 'webgl',
    kernelFunc: avgPool3DGrad
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * WebGL kernel function for the gradient of a 2D average pool.
 *
 * @param args Kernel arguments: `inputs.dy` (upstream gradient),
 *     `inputs.input` (the forward input), `backend`, and `attrs` carrying
 *     filterSize, strides and pad.
 * @returns The result of running AvgPool2DBackpropProgram on `dy`, using the
 *     dtype of the forward input.
 */
function avgPoolGrad(args) {
    var backend = args.backend;
    var attrs = args.attrs;
    var dy = args.inputs.dy;
    var x = args.inputs.input;
    // Both tensors are checked for complex dtype before any pooling setup.
    assertNotComplex([dy, x], 'avgPoolGrad');
    var filterSize = attrs.filterSize;
    var strides = attrs.strides;
    var pad = attrs.pad;
    var poolInfo = tf.backend_util.computePool2DInfo(x.shape, filterSize, strides, 1 /* dilations */, pad);
    var backpropProgram = new AvgPool2DBackpropProgram(poolInfo);
    return backend.runWebGLProgram(backpropProgram, [dy], x.dtype);
}
|
|
// Kernel config pairing the AvgPoolGrad kernel name with the WebGL
// `avgPoolGrad` kernel function.
var avgPoolGradConfig = {
    kernelName: tf.AvgPoolGrad,
    backendName: 'webgl',
    kernelFunc: avgPoolGrad
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * WebGL kernel function for BatchMatMul: unpacks the kernel arguments and
 * forwards them to `batchMatMulImpl`.
 *
 * @param args Kernel arguments: `inputs.a`, `inputs.b`, `backend`, and
 *     `attrs.transposeA` / `attrs.transposeB`.
 * @returns Whatever `batchMatMulImpl` returns.
 */
function batchMatMul(args) {
    var operands = args.inputs;
    var options = args.attrs;
    var implArgs = {
        a: operands.a,
        b: operands.b,
        transposeA: options.transposeA,
        transposeB: options.transposeB,
        backend: args.backend
    };
    return batchMatMulImpl(implArgs);
}
|
|
// Kernel config pairing the BatchMatMul kernel name with the WebGL
// `batchMatMul` kernel function.
var batchMatMulConfig = {
    kernelName: tf.BatchMatMul,
    backendName: 'webgl',
    kernelFunc: batchMatMul
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2017 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Unpacked shader program description for batch normalization.
 *
 * The generated GLSL computes
 *   (x - mean) * scale * inversesqrt(variance + epsilon) + offset
 * where `offset` defaults to 0.0 and `scale` to 1.0 when the corresponding
 * shape is null/undefined.
 */
var BatchNormProgram = /** @class */ (function () {
    /**
     * @param xShape Shape of the input; also used as the output shape.
     * @param meanShape Shape of the mean; must broadcast against xShape.
     * @param varianceShape Shape of the variance; must broadcast against xShape.
     * @param offsetShape Shape of the optional offset, or null.
     * @param scaleShape Shape of the optional scale, or null.
     * @param varianceEpsilon Value spliced into the shader as the epsilon.
     */
    function BatchNormProgram(xShape, meanShape, varianceShape, offsetShape, scaleShape, varianceEpsilon) {
        var hasOffset = offsetShape != null;
        var hasScale = scaleShape != null;
        this.outputShape = [];
        this.variableNames = ['x', 'mean', 'variance'];
        // These calls are made only for their validation side effect.
        tf.backend_util.assertAndGetBroadcastShape(xShape, meanShape);
        tf.backend_util.assertAndGetBroadcastShape(xShape, varianceShape);
        if (hasOffset) {
            tf.backend_util.assertAndGetBroadcastShape(xShape, offsetShape);
            this.variableNames.push('offset');
        }
        if (hasScale) {
            tf.backend_util.assertAndGetBroadcastShape(xShape, scaleShape);
            this.variableNames.push('scale');
        }
        var offsetSnippet = hasOffset ? 'getOffsetAtOutCoords()' : '0.0';
        var scaleSnippet = hasScale ? 'getScaleAtOutCoords()' : '1.0';
        this.outputShape = xShape;
        this.userCode = "\n void main() {\n float x = getXAtOutCoords();\n float mean = getMeanAtOutCoords();\n float variance = getVarianceAtOutCoords();\n float offset = " + offsetSnippet + ";\n float scale = " + scaleSnippet + ";\n float inv = scale * inversesqrt(variance + float(" + varianceEpsilon + "));\n setOutput(dot(vec3(x, -mean, offset), vec3(inv, inv, 1)));\n }\n ";
    }
    return BatchNormProgram;
}());
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Packed (vec4) shader program description for batch normalization.
 *
 * The generated GLSL computes (x - mean) * inv + offset with
 * inv = scale * inversesqrt(variance + epsilon); `offset` defaults to
 * vec4(0.0) and `scale` to vec4(1.0) when the corresponding shape is
 * null/undefined.
 */
var BatchNormPackedProgram = /** @class */ (function () {
    /**
     * @param xShape Shape of the input; also used as the output shape.
     * @param meanShape Shape of the mean; must broadcast against xShape.
     * @param varianceShape Shape of the variance; must broadcast against xShape.
     * @param offsetShape Shape of the optional offset, or null.
     * @param scaleShape Shape of the optional scale, or null.
     * @param varianceEpsilon Value spliced into the shader as the epsilon.
     */
    function BatchNormPackedProgram(xShape, meanShape, varianceShape, offsetShape, scaleShape, varianceEpsilon) {
        var hasOffset = offsetShape != null;
        var hasScale = scaleShape != null;
        this.packedInputs = true;
        this.packedOutput = true;
        this.variableNames = ['x', 'mean', 'variance'];
        // These calls are made only for their validation side effect.
        tf.backend_util.assertAndGetBroadcastShape(xShape, meanShape);
        tf.backend_util.assertAndGetBroadcastShape(xShape, varianceShape);
        if (hasOffset) {
            tf.backend_util.assertAndGetBroadcastShape(xShape, offsetShape);
            this.variableNames.push('offset');
        }
        if (hasScale) {
            tf.backend_util.assertAndGetBroadcastShape(xShape, scaleShape);
            this.variableNames.push('scale');
        }
        var offsetSnippet = hasOffset ? 'getOffsetAtOutCoords()' : 'vec4(0.0)';
        var scaleSnippet = hasScale ? 'getScaleAtOutCoords()' : 'vec4(1.0)';
        this.outputShape = xShape;
        this.userCode = "\n void main() {\n vec4 offset = " + offsetSnippet + ";\n vec4 scale = " + scaleSnippet + ";\n\n vec4 x = getXAtOutCoords();\n vec4 mean = getMeanAtOutCoords();\n vec4 variance = getVarianceAtOutCoords();\n\n vec4 inv = scale * inversesqrt(variance + vec4(" + varianceEpsilon + "));\n\n setOutput((x - mean) * inv + offset);\n }\n ";
    }
    return BatchNormPackedProgram;
}());
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * WebGL kernel function for FusedBatchNorm.
 *
 * Checks that mean, variance and the optional offset/scale share one rank,
 * defaults `attrs.varianceEpsilon` to 0.001 when it is null/undefined, and
 * runs the packed or unpacked batch-norm program depending on the
 * WEBGL_PACK_NORMALIZATION flag.
 *
 * @param args Kernel arguments: `inputs` (x, mean, variance, optional offset
 *     and scale), `backend`, and `attrs.varianceEpsilon`.
 * @returns The program output, with the dtype of `x`.
 */
var batchNorm = function (args) {
    var backend = args.backend;
    var x = args.inputs.x;
    var mean = args.inputs.mean;
    var variance = args.inputs.variance;
    var offset = args.inputs.offset;
    var scale = args.inputs.scale;
    var meanRank = mean.shape.length;
    tf.util.assert(meanRank === variance.shape.length, function () {
        return 'Batch normalization gradient requires mean and variance to have ' +
            'equal ranks.';
    });
    tf.util.assert(offset == null || meanRank === offset.shape.length, function () {
        return 'Batch normalization gradient requires mean and offset to have ' +
            'equal ranks.';
    });
    tf.util.assert(scale == null || meanRank === scale.shape.length, function () {
        return 'Batch normalization gradient requires mean and scale to have ' +
            'equal ranks.';
    });
    var varianceEpsilon = attrsEpsilonOrDefault(args.attrs.varianceEpsilon);
    // Optional inputs are appended in the order the programs declare their
    // variables: offset first, then scale.
    var programInputs = [x, mean, variance];
    var offsetShape = null;
    var scaleShape = null;
    if (offset != null) {
        offsetShape = offset.shape;
        programInputs.push(offset);
    }
    if (scale != null) {
        scaleShape = scale.shape;
        programInputs.push(scale);
    }
    var ProgramCtor = tf.env().getBool('WEBGL_PACK_NORMALIZATION') ?
        BatchNormPackedProgram :
        BatchNormProgram;
    var program = new ProgramCtor(x.shape, mean.shape, variance.shape, offsetShape, scaleShape, varianceEpsilon);
    return backend.runWebGLProgram(program, programInputs, programInputs[0].dtype);
    // Falls back to 0.001 only for null/undefined; any other value is kept.
    function attrsEpsilonOrDefault(value) {
        return value == null ? 0.001 : value;
    }
};
|
|
// Kernel config pairing the FusedBatchNorm kernel name with the WebGL
// `batchNorm` kernel function.
var batchNormConfig = {
    kernelName: tf.FusedBatchNorm,
    backendName: 'webgl',
    kernelFunc: batchNorm
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2017 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Unpacked shader program description for slicing.
 *
 * The output shape is the requested slice size. The slice origin is not
 * baked into the shader: it is supplied through the `start` int-array
 * uniform, set by the function returned from `getCustomSetupFunc`.
 */
var SliceProgram = /** @class */ (function () {
    /**
     * @param destSize Shape of the slice to produce; its length is the rank.
     */
    function SliceProgram(destSize) {
        this.variableNames = ['source'];
        this.outputShape = destSize;
        this.rank = destSize.length;
        var coordType = getCoordsDataType(this.rank);
        var uniformDecl = "uniform int start[" + this.rank + "];";
        var sourceArgs = getCoords(this.rank);
        // One assignment per axis: source coordinate = start + output coordinate.
        var axisAssignments = destSize.map(function (_, axis) {
            var channel = coords[axis];
            return "sourceLoc." + channel + " = start[" + axis + "] + coords." + channel + ";";
        });
        var mainBody = "\n " + coordType + " sourceLoc;\n " + coordType + " coords = getOutputCoords();\n " + axisAssignments.join('\n') + "\n ";
        this.userCode = "\n " + uniformDecl + "\n void main() {\n " + mainBody + "\n setOutput(getSource(" + sourceArgs + "));\n }\n ";
    }
    /**
     * Builds the callback that uploads `start` into the `start` uniform.
     *
     * @param start Slice origin; its length must equal the program rank.
     * @returns A function (gpgpu, webGLProgram) that looks up the uniform
     *     location once, caches it on this program, and uploads `start`.
     * @throws Error when `start.length` differs from the program rank.
     */
    SliceProgram.prototype.getCustomSetupFunc = function (start) {
        var program = this;
        if (start.length !== this.rank) {
            throw Error("The rank (" + this.rank + ") of the program must match the " +
                ("length of start (" + start.length + ")"));
        }
        return function (gpgpu, webGLProgram) {
            if (program.startLoc == null) {
                program.startLoc = gpgpu.getUniformLocationNoThrow(webGLProgram, 'start');
            }
            if (program.startLoc == null) {
                // This means the compiler has optimized and realized it doesn't
                // need the uniform.
                return;
            }
            gpgpu.gl.uniform1iv(program.startLoc, start);
        };
    };
    return SliceProgram;
}());
|
|
// Channel names used to address up to six coordinate components.
var coords = ['x', 'y', 'z', 'w', 'u', 'v'];
/**
 * Builds the argument list used to sample `sourceLoc` in the slice shader.
 *
 * @param rank Number of dimensions of the slice.
 * @returns 'sourceLoc' for rank 1; otherwise the comma-joined
 *     'sourceLoc.<channel>' entries for the first `rank` channels.
 * @throws Error when `rank` is not 1 and not <= 6.
 */
function getCoords(rank) {
    if (rank === 1) {
        return 'sourceLoc';
    }
    if (!(rank <= 6)) {
        throw Error("Slicing for rank " + rank + " is not yet supported");
    }
    var channels = coords.slice(0, rank);
    var qualified = [];
    for (var i = 0; i < channels.length; i++) {
        qualified.push('sourceLoc.' + channels[i]);
    }
    return qualified.join(',');
}
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Packed shader program description for slicing.
 *
 * Each invocation fills a vec4: x always, y when the next position along the
 * last axis is inside the slice, and (for rank > 1) z/w for the next
 * position along the second-to-last axis. The slice origin is supplied
 * through the `start` int-array uniform.
 */
var SlicePackedProgram = /** @class */ (function () {
    /**
     * @param destSize Shape of the slice to produce; its length is the rank.
     */
    function SlicePackedProgram(destSize) {
        this.variableNames = ['source'];
        this.packedInputs = true;
        this.packedOutput = true;
        this.outputShape = destSize;
        this.rank = destSize.length;
        var rank = this.rank;
        var last = rank - 1;
        var secondLast = rank - 2;
        var coordType = getCoordsDataType(rank);
        var outChannels = getChannels('coords', rank);
        var srcChannels = getChannels('sourceLoc', rank);
        var innerDims;
        if (rank === 1) {
            innerDims = 'sourceLoc';
        }
        else {
            innerDims = "vec2(" + srcChannels.slice(-2).join() + ")";
        }
        var sample = "getChannel(getSource(" + srcChannels.join() + "), " + innerDims + ")";
        // Fills result.x and, when still inside the slice, result.y.
        var upperRow = "\n result.x = " + sample + ";\n if (++" + outChannels[last] + " < " + destSize[last] + ") {\n ++" + srcChannels[last] + ";\n result.y = " + sample + ";\n --" + srcChannels[last] + ";\n }\n ";
        // Fills result.z / result.w; there is no second row for rank 1.
        var lowerRow = '';
        if (rank !== 1) {
            lowerRow = "\n --" + outChannels[last] + ";\n if (++" + outChannels[secondLast] + " < " + destSize[secondLast] + ") {\n ++" + srcChannels[secondLast] + ";\n result.z = " + sample + ";\n if (++" + outChannels[last] + " < " + destSize[last] + ") {\n ++" + srcChannels[last] + ";\n result.w = " + sample + ";\n }\n }\n ";
        }
        var sourceLocSetup;
        if (rank <= 4) {
            // Up to rank 4 the offset is added as one vector expression.
            var startComponents = destSize.map(function (_, axis) { return "start[" + axis + "]"; });
            sourceLocSetup = "sourceLoc = coords +\n " + coordType + "(" + startComponents.join() + ");";
        }
        else {
            // Higher ranks assign each component separately.
            sourceLocSetup = destSize
                .map(function (_, axis) {
                return srcChannels[axis] + " = " + outChannels[axis] + " + start[" + axis + "];";
            })
                .join('\n');
        }
        this.userCode = "\n uniform int start[" + rank + "];\n void main() {\n " + coordType + " coords = getOutputCoords();\n " + coordType + " sourceLoc;\n " + sourceLocSetup + "\n vec4 result = vec4(0.);\n " + upperRow + "\n " + lowerRow + "\n setOutput(result);\n }\n ";
    }
    /**
     * Builds the callback that uploads `start` into the `start` uniform.
     *
     * @param start Slice origin; its length must equal the program rank.
     * @returns A function (gpgpu, webGLProgram) that looks up the uniform
     *     location once, caches it on this program, and uploads `start`.
     * @throws Error when `start.length` differs from the program rank.
     */
    SlicePackedProgram.prototype.getCustomSetupFunc = function (start) {
        var program = this;
        if (start.length !== this.rank) {
            throw Error("The rank (" + this.rank + ") of the program must match the " +
                ("length of start (" + start.length + ")"));
        }
        return function (gpgpu, webGLProgram) {
            if (program.startLoc == null) {
                program.startLoc = gpgpu.getUniformLocationNoThrow(webGLProgram, 'start');
            }
            if (program.startLoc == null) {
                // This means the compiler has optimized and realized it doesn't
                // need the uniform.
                return;
            }
            gpgpu.gl.uniform1iv(program.startLoc, start);
        };
    };
    return SlicePackedProgram;
}());
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Creates a slice that reuses the texture bookkeeping of `x` instead of
 * running a shader: the new tensor info gets a copy of x's texData fields
 * plus a `slice` record holding the flat offset and the original data id.
 *
 * @param x Tensor info being sliced.
 * @param begin Start coordinates of the slice.
 * @param size Shape of the slice.
 * @param backend WebGL backend owning `texData` and `dataRefCount`.
 * @returns The newly created tensor info.
 */
function shallowSlice(x, begin, size, backend) {
    var sourceData = backend.texData.get(x.dataId);
    var sliced = backend.makeTensorInfo(size, x.dtype);
    var slicedData = backend.texData.get(sliced.dataId);
    // Copy texture data from the original tensor, then restore the fields
    // that belong to the new tensor.
    Object.assign(slicedData, sourceData);
    slicedData.refCount = 1;
    slicedData.shape = size;
    slicedData.dtype = x.dtype;
    var parentSlice = sourceData.slice;
    var flatOffset = tf.slice_util.computeFlatOffset(begin, tf.util.computeStrides(x.shape));
    if (parentSlice) {
        // Slicing an already sliced tensor: offsets accumulate.
        flatOffset += parentSlice.flatOffset;
    }
    // Point to the original dataId, which is used to do ref counting.
    var origDataId = parentSlice && parentSlice.origDataId || x.dataId;
    slicedData.slice = { flatOffset: flatOffset, origDataId: origDataId };
    // Increase the ref count for that data bucket.
    var previousCount = backend.dataRefCount.get(slicedData.slice.origDataId) || 1;
    backend.dataRefCount.set(slicedData.slice.origDataId, previousCount + 1);
    return sliced;
}
|
|
/**
 * WebGL kernel function for Slice.
 *
 * Paths, in order: empty result -> empty tensor info; CPU-eligible or string
 * input -> `sliceImplCPU`; packed or non-contiguous input -> slice shader;
 * otherwise -> `shallowSlice` after uploading `x` to the GPU.
 *
 * @param args Kernel arguments: `inputs.x`, `backend`, and `attrs.begin` /
 *     `attrs.size`.
 * @returns Tensor info of the slice.
 */
function slice(args) {
    var backend = args.backend;
    var x = args.inputs.x;
    var parsed = tf.slice_util.parseSliceParams(x, args.attrs.begin, args.attrs.size);
    var $begin = parsed[0];
    var $size = parsed[1];
    tf.slice_util.assertParamsValid(x, $begin, $size);
    if (tf.util.sizeFromShape($size) === 0) {
        return backend.makeTensorInfo($size, x.dtype, []);
    }
    // Run on cpu if dtype is string. For string, the backend represents it
    // as Uint8Array[], where each Uint8Array is a character. Given that the
    // computation is only on the outer array, uploading the whole data onto
    // gpu is wasteful. Also, currently webgl doesn't have a design to
    // upload and retrieve Uint8Array[] between cpu and gpu. Therefore, we
    // just run the kernel on cpu if dtype is string.
    if (backend.shouldExecuteOnCPU([x]) || x.dtype === 'string') {
        var cpuValues = backend.texData.get(x.dataId).values;
        var slicedValues = sliceImplCPU(cpuValues, $begin, $size, x.shape, x.dtype);
        return backend.makeTensorInfo($size, x.dtype, slicedValues);
    }
    var inputIsPacked = backend.texData.get(x.dataId).isPacked;
    var sliceIsContiguous = tf.slice_util.isSliceContinous(x.shape, $begin, $size);
    if (!inputIsPacked && sliceIsContiguous) {
        // Unpacked and contiguous: share the texture instead of copying.
        backend.uploadToGPU(x.dataId);
        return shallowSlice(x, $begin, $size, backend);
    }
    var program;
    if (tf.env().getBool('WEBGL_PACK_ARRAY_OPERATIONS')) {
        program = new SlicePackedProgram($size);
    }
    else {
        program = new SliceProgram($size);
    }
    var uploadStart = program.getCustomSetupFunc($begin);
    return backend.runWebGLProgram(program, [x], x.dtype, uploadStart);
}
|
|
// Kernel config pairing the Slice kernel name with the WebGL `slice`
// kernel function.
var sliceConfig = {
    kernelName: tf.Slice,
    backendName: 'webgl',
    kernelFunc: slice
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * WebGL kernel function for BatchToSpaceND, implemented as
 * reshape -> transpose -> reshape -> slice. The three intermediate tensors
 * are disposed before the sliced result is returned.
 *
 * @param args Kernel arguments: `inputs.x`, `backend`, and
 *     `attrs.blockShape` / `attrs.crops`.
 * @returns Tensor info produced by the final slice.
 */
var batchToSpaceND = function (args) {
    var backend = args.backend;
    var x = args.inputs.x;
    var blockShape = args.attrs.blockShape;
    var crops = args.attrs.crops;
    tf.util.assert(x.shape.length <= 4, function () {
        return 'batchToSpaceND for rank > 4 with a WebGL backend not ' +
            'implemented yet';
    });
    var blockRank = blockShape.length;
    var blockVolume = blockShape.reduce(function (a, b) { return a * b; });
    // Shapes and permutation for each step, all derived up front.
    var reshapedShape = tf.backend_util.getReshaped(x.shape, blockShape, blockVolume);
    var permutation = tf.backend_util.getPermuted(reshapedShape.length, blockRank);
    var mergedShape = tf.backend_util.getReshapedPermuted(x.shape, blockShape, blockVolume);
    var sliceBegin = tf.backend_util.getSliceBeginCoords(crops, blockRank);
    var sliceSize = tf.backend_util.getSliceSize(mergedShape, crops, blockRank);
    var reshapedOnce = reshape({
        inputs: { x: x },
        backend: backend,
        attrs: { shape: reshapedShape }
    });
    var transposed = transpose({
        inputs: { x: reshapedOnce },
        backend: backend,
        attrs: { perm: permutation }
    });
    var reshapedTwice = reshape({
        inputs: { x: transposed },
        backend: backend,
        attrs: { shape: mergedShape }
    });
    var sliced = slice({
        inputs: { x: reshapedTwice },
        backend: backend,
        attrs: { begin: sliceBegin, size: sliceSize }
    });
    // Release the intermediates in creation order.
    backend.disposeIntermediateTensorInfo(reshapedOnce);
    backend.disposeIntermediateTensorInfo(transposed);
    backend.disposeIntermediateTensorInfo(reshapedTwice);
    return sliced;
};
|
|
// Kernel config pairing the BatchToSpaceND kernel name with the WebGL
// `batchToSpaceND` kernel function.
var batchToSpaceNDConfig = {
    kernelName: tf.BatchToSpaceND,
    backendName: 'webgl',
    kernelFunc: batchToSpaceND
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * WebGL kernel function for Bincount. Reads `x` and `weights` synchronously
 * from the backend and delegates the counting to `bincountImplCPU`.
 *
 * @param args Kernel arguments: `inputs.x`, `inputs.weights`, `backend`, and
 *     `attrs.size`.
 * @returns A tensor info of shape [size] with the dtype of `weights`.
 */
function bincount(args) {
    var backend = args.backend;
    var x = args.inputs.x;
    var weights = args.inputs.weights;
    var size = args.attrs.size;
    var xValues = backend.readSync(x.dataId);
    var weightValues = backend.readSync(weights.dataId);
    var counts = bincountImplCPU(xValues, weightValues, weights.dtype, weights.shape, size);
    return backend.makeTensorInfo([size], weights.dtype, counts);
}
|
|
// Kernel config pairing the Bincount kernel name with the WebGL `bincount`
// kernel function.
var bincountConfig = {
    kernelName: tf.Bincount,
    backendName: 'webgl',
    kernelFunc: bincount
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// Shader snippet for the element-wise "not equal" comparison.
var NOT_EQUAL = 'return float(a != b);';
// Kernel function built by `binaryKernelFunc` from the snippet above and the
// `notEqualImplCPU` implementation; the result dtype is 'bool'.
var notEqual = binaryKernelFunc({
    opSnippet: NOT_EQUAL,
    cpuKernelImpl: notEqualImplCPU,
    dtype: 'bool'
});
|
|
// Kernel config pairing the NotEqual kernel name with the WebGL `notEqual`
// kernel function.
var notEqualConfig = {
    kernelName: tf.NotEqual,
    backendName: 'webgl',
    kernelFunc: notEqual
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * WebGL kernel function for Real. Looks up the texData entry of the input
 * and returns `identity` applied to its `complexTensorInfos.real` component.
 *
 * @param args Kernel arguments: `inputs.input` and `backend`.
 * @returns The result of `identity` on the stored real component.
 */
function real(args) {
    var backend = args.backend;
    var complexInput = args.inputs.input;
    var realPart = backend.texData.get(complexInput.dataId).complexTensorInfos.real;
    return identity({ inputs: { x: realPart }, backend: backend });
}
|
|
// Kernel config pairing the Real kernel name with the WebGL `real` kernel
// function.
var realConfig = {
    kernelName: tf.Real,
    backendName: 'webgl',
    kernelFunc: real
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// Shader snippet that converts a value to int and back to float.
var TO_INT = 'return float(int(x));';
/**
 * Runs the TO_INT unary program over `input` and reports the result as an
 * 'int32' tensor.
 *
 * @param input Tensor info to convert.
 * @param backend WebGL backend used to run the program.
 * @returns A new plain object holding the output's dataId, shape and dtype.
 */
function int(input, backend) {
    var toIntProgram = new UnaryOpProgram(input.shape, TO_INT);
    var converted = backend.runWebGLProgram(toIntProgram, [input], 'int32');
    return {
        dataId: converted.dataId,
        shape: converted.shape,
        dtype: converted.dtype
    };
}
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * WebGL kernel function for Cast.
 *
 * Cases, in order:
 *  - target 'complex64': identity for complex input, otherwise the input is
 *    cast to float32 and combined with a zero imaginary part;
 *  - complex64 input: the real part is extracted and cast recursively;
 *  - no encoding loss (per `tf.util.hasEncodingLoss`): the data is reused
 *    through `identity` and only the reported dtype changes;
 *  - target 'int32': the `int` shader helper;
 *  - target 'bool': `notEqual` against a single-element bool tensor.
 *
 * @param args Kernel arguments: `inputs.x`, `backend`, and `attrs.dtype`.
 * @returns Tensor info holding the cast values.
 * @throws Error when none of the cases applies.
 */
function cast(args) {
    var backend = args.backend;
    var x = args.inputs.x;
    var dtype = args.attrs.dtype;
    // Casting to complex64.
    if (dtype === 'complex64') {
        if (x.dtype === 'complex64') {
            return identity({ inputs: { x: x }, backend: backend });
        }
        // TODO(annxingyuan): Import kernel function once zeros is modularized.
        var imagZeros = tf.zeros(x.shape);
        var realAsFloat = cast({
            inputs: { x: x },
            backend: backend,
            attrs: { dtype: 'float32' }
        });
        var asComplex = complex({
            inputs: { real: realAsFloat, imag: imagZeros },
            backend: backend
        });
        imagZeros.dispose();
        backend.disposeIntermediateTensorInfo(realAsFloat);
        return asComplex;
    }
    // Casting from complex64
    if (x.dtype === 'complex64') {
        var realPart = real({ inputs: { input: x }, backend: backend });
        var fromReal = cast({
            inputs: { x: realPart },
            backend: backend,
            attrs: { dtype: dtype }
        });
        backend.disposeIntermediateTensorInfo(realPart);
        return fromReal;
    }
    if (!tf.util.hasEncodingLoss(x.dtype, dtype)) {
        // We don't change the underlying data, since we cast to higher
        // precision.
        var sameData = identity({ inputs: { x: x }, backend: backend });
        return { dataId: sameData.dataId, shape: sameData.shape, dtype: dtype };
    }
    if (dtype === 'int32') {
        return int(x, backend);
    }
    if (dtype === 'bool') {
        // Compare against a single-element bool tensor: non-equal entries
        // become true.
        var zeroInfo = backend.makeTensorInfo([], 'bool', tf.util.getTypedArrayFromDType('bool', 1));
        var asBool = notEqual({ inputs: { a: x, b: zeroInfo }, backend: backend });
        backend.disposeIntermediateTensorInfo(zeroInfo);
        return asBool;
    }
    throw new Error("Error in Cast: failed to cast " + x.dtype + " to " + dtype);
}
|
|
// Kernel config pairing the tf.Cast kernel name with the WebGL `cast` function.
var castConfig = {
    kernelName: tf.Cast,
    backendName: 'webgl',
    kernelFunc: cast
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// GLSL body used for both the unpacked and the packed ceil shader.
var CEIL = "return ceil(x);";
// Kernel function built by unaryKernelFunc from the CEIL snippet, with
// ceilImplCPU supplied as the CPU implementation.
var ceil = unaryKernelFunc({
    opSnippet: CEIL,
    packedOpSnippet: CEIL,
    cpuKernelImpl: ceilImplCPU
});
// Kernel config pairing the tf.Ceil kernel name with the WebGL `ceil` function.
var ceilConfig = {
    kernelName: tf.Ceil,
    backendName: 'webgl',
    kernelFunc: ceil
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2017 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
var ClipProgram = /** @class */ (function () {
    // Shader source: NaN inputs are written through unchanged, everything else
    // is clamped to the [minVal, maxVal] uniforms.
    var CLIP_SHADER = "\n uniform float minVal;\n uniform float maxVal;\n\n void main() {\n float value = getAAtOutCoords();\n if (isnan(value)) {\n setOutput(value);\n return;\n }\n\n setOutput(clamp(value, minVal, maxVal));\n }\n ";
    /**
     * Unpacked clip-by-value program with a single input named 'A'.
     *
     * @param aShape shape of the input; also used as the output shape.
     */
    function ClipProgram(aShape) {
        this.variableNames = ['A'];
        this.outputShape = aShape;
        this.userCode = CLIP_SHADER;
    }
    /**
     * Builds the callback that uploads the clip bounds as uniforms.
     *
     * @param min value written to the `minVal` uniform.
     * @param max value written to the `maxVal` uniform.
     * @returns function `(gpgpu, webGLProgram)`; the uniform locations are
     *     looked up on first use and cached on this program instance.
     */
    ClipProgram.prototype.getCustomSetupFunc = function (min, max) {
        var program = this;
        return function (gpgpu, webGLProgram) {
            var needsLookup = program.minLoc == null;
            if (needsLookup) {
                program.minLoc = gpgpu.getUniformLocationNoThrow(webGLProgram, 'minVal');
                program.maxLoc = gpgpu.getUniformLocationNoThrow(webGLProgram, 'maxVal');
            }
            gpgpu.gl.uniform1f(program.minLoc, min);
            gpgpu.gl.uniform1f(program.maxLoc, max);
        };
    };
    return ClipProgram;
}());
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
var ClipPackedProgram = /** @class */ (function () {
    // Shader source: if any lane of the vec4 is NaN the whole texel is written
    // through unchanged; otherwise every lane is clamped to [minVal, maxVal].
    var PACKED_CLIP_SHADER = "\n uniform float minVal;\n uniform float maxVal;\n\n void main() {\n vec4 value = getAAtOutCoords();\n\n if (any(isnan(value))) {\n setOutput(value);\n return;\n }\n\n setOutput(clamp(value, vec4(minVal), vec4(maxVal)));\n }\n ";
    /**
     * Packed clip-by-value program (packed input and packed output) with a
     * single input named 'A'.
     *
     * @param aShape shape of the input; also used as the output shape.
     */
    function ClipPackedProgram(aShape) {
        this.variableNames = ['A'];
        this.packedInputs = true;
        this.packedOutput = true;
        this.outputShape = aShape;
        this.userCode = PACKED_CLIP_SHADER;
    }
    /**
     * Builds the callback that uploads the clip bounds as uniforms.
     *
     * @param min value written to the `minVal` uniform.
     * @param max value written to the `maxVal` uniform.
     * @returns function `(gpgpu, webGLProgram)`; the uniform locations are
     *     looked up on first use and cached on this program instance.
     */
    ClipPackedProgram.prototype.getCustomSetupFunc = function (min, max) {
        var program = this;
        return function (gpgpu, webGLProgram) {
            var needsLookup = program.minLoc == null;
            if (needsLookup) {
                program.minLoc = gpgpu.getUniformLocationNoThrow(webGLProgram, 'minVal');
                program.maxLoc = gpgpu.getUniformLocationNoThrow(webGLProgram, 'maxVal');
            }
            gpgpu.gl.uniform1f(program.minLoc, min);
            gpgpu.gl.uniform1f(program.maxLoc, max);
        };
    };
    return ClipPackedProgram;
}());
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * WebGL kernel function for ClipByValue.
 *
 * Uses ClipPackedProgram when the 'WEBGL_PACK_CLIP' environment flag is set
 * and ClipProgram otherwise; the clip bounds are uploaded through the
 * program's custom setup function.
 *
 * @param args `{ inputs: { x }, backend, attrs: { clipValueMin, clipValueMax } }`.
 * @returns the result of `backend.runWebGLProgram`, requested with x's dtype.
 */
function clipByValue(args) {
    var backend = args.backend;
    var x = args.inputs.x;
    var lowerBound = args.attrs.clipValueMin;
    var upperBound = args.attrs.clipValueMax;
    var clipProgram = tf.env().getBool('WEBGL_PACK_CLIP') ?
        new ClipPackedProgram(x.shape) :
        new ClipProgram(x.shape);
    var uploadBounds = clipProgram.getCustomSetupFunc(lowerBound, upperBound);
    return backend.runWebGLProgram(clipProgram, [x], x.dtype, uploadBounds);
}
|
|
// Kernel config pairing the tf.ClipByValue kernel name with the WebGL
// `clipByValue` function.
var clipByValueConfig = {
    kernelName: tf.ClipByValue,
    backendName: 'webgl',
    kernelFunc: clipByValue
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
var ComplexAbsProgram = /** @class */ (function () {
    // Shader source: magnitude of (real, imag), computed as
    // mx * length(vec2(1, min/mx)) with mx = max(|re|, |im|) and an explicit
    // zero case, because (per the shader's own comment) GLSL length() is not
    // underflow-safe on some GPUs.
    var COMPLEX_ABS_SHADER = "\n void main() {\n float re = abs(getRealAtOutCoords());\n float im = abs(getImagAtOutCoords());\n float mx = max(re, im);\n\n // sadly the length function in glsl is not underflow-safe\n // (at least not on Intel GPUs). So the safe solution is\n // to ensure underflow-safety in all cases.\n setOutput(\n mx == 0.0 ? 0.0 : mx * length(vec2(1, min(re, im)/mx))\n );\n }\n ";
    /**
     * Program with two inputs named 'real' and 'imag'.
     *
     * @param shape output shape of the program.
     */
    function ComplexAbsProgram(shape) {
        this.variableNames = ['real', 'imag'];
        this.outputShape = shape;
        this.userCode = COMPLEX_ABS_SHADER;
    }
    return ComplexAbsProgram;
}());
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Builds a tensor info that points at one component (real or imaginary) of a
 * complex tensor but reports the complex tensor's shape. This is needed
 * because reshaping a complex tensor is not reflected in its parts.
 *
 * @param complexTensor tensor info whose `shape` is used.
 * @param complexPart component info whose `dataId` and `dtype` are used.
 * @returns a new `{ dataId, dtype, shape }` object.
 */
function makeComplexComponentTensorInfo(complexTensor, complexPart) {
    var componentInfo = {};
    componentInfo.dataId = complexPart.dataId;
    componentInfo.dtype = complexPart.dtype;
    componentInfo.shape = complexTensor.shape;
    return componentInfo;
}
|
|
/**
 * WebGL kernel function for ComplexAbs.
 *
 * Looks up the real and imaginary component infos stored for `x` in
 * `backend.texData`, re-labels both with x's shape, and runs
 * ComplexAbsProgram over them.
 *
 * @param args `{ inputs: { x }, backend }`.
 * @returns the result of `backend.runWebGLProgram`, requested with the dtype
 *     of the real component.
 */
function complexAbs(args) {
    var backend = args.backend;
    var x = args.inputs.x;
    var texEntry = backend.texData.get(x.dataId);
    var absProgram = new ComplexAbsProgram(x.shape);
    var components = ['real', 'imag'].map(function (part) {
        return makeComplexComponentTensorInfo(x, texEntry.complexTensorInfos[part]);
    });
    return backend.runWebGLProgram(absProgram, components, components[0].dtype);
}
|
|
// Kernel config pairing the tf.ComplexAbs kernel name with the WebGL
// `complexAbs` function.
var complexAbsConfig = {
    kernelName: tf.ComplexAbs,
    backendName: 'webgl',
    kernelFunc: complexAbs
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2017 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
var ConcatProgram = /** @class */ (function () {
    /**
     * Unpacked program that concatenates 2D tensors along axis 1. See
     * comments in MathBackendWebGL.concat().
     *
     * The shader is an if / else-if chain over the output column `yC`: each
     * branch reads from the input whose column range contains `yC`, shifted
     * by the running column offset of the inputs before it.
     *
     * @param shapes shapes of the inputs; input i is exposed as `T<i>`.
     */
    function ConcatProgram(shapes) {
        this.outputShape = tf.backend_util.computeOutShape(shapes, 1 /* axis */);
        this.variableNames = shapes.map(function (_, i) { return "T" + i; });
        // offsets[i] is the exclusive end column of input i within the output.
        var offsets = new Array(shapes.length - 1);
        offsets[0] = shapes[0][1];
        var branches = ["if (yC < " + offsets[0] + ") setOutput(getT0(yR, yC));"];
        for (var idx = 1; idx < offsets.length; idx++) {
            var previousEnd = offsets[idx - 1];
            offsets[idx] = previousEnd + shapes[idx][1];
            branches.push("else if (yC < " + offsets[idx] + ") " +
                "setOutput(getT" + idx + "(yR, yC-" + previousEnd + "));");
        }
        // The final input needs no upper bound and is handled by a bare else.
        var finalInput = offsets.length;
        var finalShift = offsets[finalInput - 1];
        branches.push("else setOutput(getT" + finalInput + "(yR, yC-" + finalShift + "));");
        this.userCode = "\n void main() {\n ivec2 coords = getOutputCoords();\n int yR = coords.x;\n int yC = coords.y;\n\n " + branches.join('\n ') + "\n }\n ";
    }
    return ConcatProgram;
}());
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
var ConcatPackedProgram = /** @class */ (function () {
    /**
     * Packed program (packed input and packed output) that concatenates
     * tensors along `axis`.
     *
     * The generated `getValue` helper picks the input whose range along the
     * concat axis contains the requested coordinate and reads it with that
     * coordinate shifted by the running offset. `main` calls it up to four
     * times to fill the output texel.
     *
     * @param shapes shapes of the inputs; input i is exposed as `T<i>`.
     * @param axis axis along which the inputs are joined.
     */
    function ConcatPackedProgram(shapes, axis) {
        this.packedInputs = true;
        this.packedOutput = true;
        this.outputShape = tf.backend_util.computeOutShape(shapes, axis);
        var outShape = this.outputShape;
        var rank = outShape.length;
        var coordsType = getCoordsDataType(rank);
        var coords = getChannels('coords', rank);
        var channels = ['x', 'y', 'z', 'w', 'u', 'v'].slice(0, rank);
        this.variableNames = shapes.map(function (_, i) { return "T" + i; });
        var concatChannel = channels[axis];
        var innerChannels = channels.slice(-2);
        // offsets[i] is the exclusive end index of input i along the concat axis.
        var offsets = new Array(shapes.length - 1);
        offsets[0] = shapes[0][axis];
        var getValueSnippet = "if (" + concatChannel + " < " + offsets[0] + ") {\n return getChannel(\n getT0(" + channels.join() + "), vec2(" + innerChannels.join() + "));\n }";
        for (var i = 1; i < offsets.length; i++) {
            var lowerBound = offsets[i - 1];
            offsets[i] = lowerBound + shapes[i][axis];
            // Note: the >= comparison below may seem unnecessary given the
            // preceding check but is needed to workaround branch execution
            // issues on some devices. It makes all the conditions exclusive
            // without relying on execution order.
            getValueSnippet += "\n if (" + concatChannel + " < " + offsets[i] + " && " + concatChannel + " >= " + lowerBound + ") {\n return getChannel(\n getT" + i + "(" + shiftedChannels(channels, concatChannel, lowerBound) + "),\n vec2(" + shiftedChannels(innerChannels, concatChannel, lowerBound) + "));\n }";
        }
        // The final input is the unconditional fall-through.
        var finalInput = offsets.length;
        var finalShift = offsets[finalInput - 1];
        getValueSnippet += "\n return getChannel(\n getT" + finalInput + "(" + shiftedChannels(channels, concatChannel, finalShift) + "),\n vec2(" + shiftedChannels(innerChannels, concatChannel, finalShift) + "));";
        var signature = channels.map(function (c) { return 'int ' + c; });
        var lastCoord = coords[rank - 1];
        var lastDim = outShape[rank - 1];
        var prevCoord = coords[rank - 2];
        var prevDim = outShape[rank - 2];
        this.userCode = "\n float getValue(" + signature + ") {\n " + getValueSnippet + "\n }\n\n void main() {\n " + coordsType + " coords = getOutputCoords();\n vec4 result = vec4(getValue(" + coords + "), 0., 0., 0.);\n\n " + lastCoord + " = " + lastCoord + " + 1;\n if (" + lastCoord + " < " + lastDim + ") {\n result.g = getValue(" + coords + ");\n }\n\n " + prevCoord + " = " + prevCoord + " + 1;\n if (" + prevCoord + " < " + prevDim + ") {\n result.a = getValue(" + coords + ");\n }\n\n " + lastCoord + " = " + lastCoord + " - 1;\n if (" + prevCoord + " < " + prevDim + " &&\n " + lastCoord + " < " + lastDim + ") {\n result.b = getValue(" + coords + ");\n }\n setOutput(result);\n }\n ";
    }
    return ConcatPackedProgram;
}());
|
|
/**
 * Builds a comma-separated coordinate expression in which one channel has
 * `shift` subtracted from it.
 *
 * @param channels the channel names to emit, in order.
 * @param channel the channel to shift; only its first occurrence in
 *     `channels` is shifted, and nothing is shifted if it is absent.
 * @param shift the amount to subtract from the channel.
 *
 * @returns a string of the form 'x,y - [shift],z'.
 */
function shiftedChannels(channels, channel, shift) {
    var shiftedAt = channels.indexOf(channel);
    var expressions = [];
    for (var i = 0; i < channels.length; i++) {
        expressions.push(i === shiftedAt ? channels[i] + " - " + shift : channels[i]);
    }
    return expressions.join();
}
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * WebGL kernel function for Imag.
 *
 * Looks up the entry stored for the input in `backend.texData` and passes
 * its imaginary component info through `identity`.
 *
 * @param args `{ inputs: { input }, backend }`.
 * @returns the result of `identity` applied to the imaginary component.
 */
function imag(args) {
    var backend = args.backend;
    var complexInput = args.inputs.input;
    var imagInfo = backend.texData.get(complexInput.dataId).complexTensorInfos.imag;
    return identity({ inputs: { x: imagInfo }, backend: backend });
}
|
|
// Kernel config pairing the tf.Imag kernel name with the WebGL `imag` function.
var imagConfig = {
    kernelName: tf.Imag,
    backendName: 'webgl',
    kernelFunc: imag
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Concatenates `inputs` along `axis` for the WebGL backend.
 *
 * Strategy, in order:
 *  1. complex64: concatenate the real and imaginary parts separately and
 *     recombine them with `complex`.
 *  2. CPU: when the backend asks for it, or when the dtype is 'string',
 *     read the values back and use `concatImplCPU`.
 *  3. More inputs than 'WEBGL_MAX_TEXTURES_IN_SHADER': split the list in
 *     half, concatenate each half recursively, then join the two results.
 *  4. Packed shader when 'WEBGL_PACK_ARRAY_OPERATIONS' is set and the first
 *     input has rank > 1.
 *  5. Otherwise: reshape to 2D, run ConcatProgram, reshape the result back.
 *
 * Intermediate tensor infos created along the way are disposed through
 * `backend.disposeIntermediateTensorInfo`.
 *
 * @param inputs tensor infos to join; the dtype is taken from the first one.
 * @param axis axis along which to concatenate.
 * @param backend the WebGL backend.
 * @returns tensor info of the concatenated result.
 */
function concatImpl$1(inputs, axis, backend) {
    function shapeOf(info) { return info.shape; }
    function release(info) { backend.disposeIntermediateTensorInfo(info); }
    var dtype = inputs[0].dtype;
    if (dtype === 'complex64') {
        var realParts = inputs.map(function (t) { return real({ inputs: { input: t }, backend: backend }); });
        var imagParts = inputs.map(function (t) { return imag({ inputs: { input: t }, backend: backend }); });
        var realJoined = concatImpl$1(realParts, axis, backend);
        var imagJoined = concatImpl$1(imagParts, axis, backend);
        var complexJoined = complex({ inputs: { real: realJoined, imag: imagJoined }, backend: backend });
        realParts.concat(imagParts, [realJoined, imagJoined]).forEach(release);
        return complexJoined;
    }
    // Run on cpu if dtype is string. For string, the backend represents it
    // as Uint8Array[], where each Uint8Array is a character. Given that the
    // computation is only on the outer array, uploading the whole data onto
    // gpu is wasteful. Also, currently webgl doesn't have a design to
    // upload and retrieve Uint8Array[] between cpu and gpu. Therefore, we
    // just run the kernel on cpu if dtype is string.
    var useCpu = backend.shouldExecuteOnCPU(inputs) || dtype === 'string';
    if (useCpu) {
        // Any concat of n-dimensional tensors across any axis can be reduced
        // to a concatenation of two-dimensional tensors across the axis 1 by
        // first partitioning the axes of the original tensors into those less
        // than the axis to be concatenated and the rest. Then reshape the
        // tensors into a two-dimensional tensor by collapsing these two sets
        // of axes and concatenate the resulting matrices across the axis 1,
        // finally reshaping the result to have the proper shape.
        var flattened = inputs.map(function (t) {
            var trailingSize = tf.util.sizeFromShape(t.shape.slice(axis));
            return reshape({ inputs: { x: t }, backend: backend, attrs: { shape: [-1, trailingSize] } });
        });
        var hostInputs = flattened.map(function (t) {
            return { vals: backend.readSync(t.dataId), shape: t.shape };
        });
        // Concats 2d tensors along axis=1.
        var flatOutShape = tf.backend_util.computeOutShape(flattened.map(shapeOf), 1 /* axis */);
        var singleRow = flattened[0].shape[0] === 1;
        var hostValues = concatImplCPU(hostInputs, flatOutShape, dtype, singleRow);
        var cpuOutShape = tf.backend_util.computeOutShape(inputs.map(shapeOf), axis);
        var cpuResult = backend.makeTensorInfo(cpuOutShape, dtype, hostValues);
        flattened.forEach(release);
        return cpuResult;
    }
    if (inputs.length > tf.env().getNumber('WEBGL_MAX_TEXTURES_IN_SHADER')) {
        // Too many inputs for one shader: divide and conquer.
        var half = Math.floor(inputs.length / 2);
        var firstHalf = concatImpl$1(inputs.slice(0, half), axis, backend);
        var secondHalf = concatImpl$1(inputs.slice(half), axis, backend);
        var merged = concatImpl$1([firstHalf, secondHalf], axis, backend);
        release(firstHalf);
        release(secondHalf);
        return merged;
    }
    var usePacked = tf.env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') &&
        inputs[0].shape.length > 1;
    if (usePacked) {
        var packedProgram = new ConcatPackedProgram(inputs.map(shapeOf), axis);
        return backend.runWebGLProgram(packedProgram, inputs, dtype);
    }
    var flat = computeTensors2D(inputs, axis, backend);
    var unpackedProgram = new ConcatProgram(flat.tensors2D.map(shapeOf));
    var flatResult = backend.runWebGLProgram(unpackedProgram, flat.tensors2D, dtype);
    flat.tensors2D.forEach(release);
    var shapedResult = reshape({ inputs: { x: flatResult }, attrs: { shape: flat.outShape }, backend: backend });
    release(flatResult);
    return shapedResult;
}
|
|
/**
 * Prepares inputs for the 2D concat shader.
 *
 * Any concat of n-dimensional tensors across any axis can be reduced to a
 * concatenation of two-dimensional tensors across the axis 1 by first
 * partitioning the axes of the original tensors into those less than the
 * axis to be concatenated and the rest. Then reshape the tensors into a
 * two-dimensional tensor by collapsing these two sets of axes and
 * concatenate the resulting matrices across the axis 1, finally reshaping
 * the result to have the proper shape.
 *
 * @param inputs tensor infos to be concatenated.
 * @param axis concat axis in the original (n-dimensional) shapes.
 * @param backend backend forwarded to `reshape`.
 * @returns `{ tensors2D, outShape }`: the inputs reshaped to
 *     `[-1, size of the dimensions from axis onward]`, and the final
 *     n-dimensional output shape.
 */
function computeTensors2D(inputs, axis, backend) {
    var inputShapes = [];
    for (var i = 0; i < inputs.length; i++) {
        inputShapes.push(inputs[i].shape);
    }
    var outShape = tf.backend_util.computeOutShape(inputShapes, axis);
    var tensors2D = [];
    for (var j = 0; j < inputs.length; j++) {
        var source = inputs[j];
        var trailingSize = tf.util.sizeFromShape(source.shape.slice(axis));
        tensors2D.push(reshape({
            inputs: { x: source },
            attrs: { shape: [-1, trailingSize] },
            backend: backend
        }));
    }
    return { tensors2D: tensors2D, outShape: outShape };
}
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * WebGL kernel function for Concat.
 *
 * Resolves the axis against the first input's shape, then:
 *  - if the output has zero elements, returns an empty tensor info of the
 *    computed shape;
 *  - otherwise drops the empty inputs; a single remaining input is passed
 *    through `identity`, and several are validated with
 *    `assertParamsConsistent` and handed to `concatImpl$1`.
 *
 * @param args `{ inputs: TensorInfo[], backend, attrs: { axis } }`.
 * @returns tensor info of the concatenated result.
 */
function concat(args) {
    function shapeOf(info) { return info.shape; }
    var backend = args.backend;
    var inputs = args.inputs;
    var axis = args.attrs.axis;
    var $axis = tf.util.parseAxisParam(axis, inputs[0].shape)[0];
    var outShape = tf.backend_util.computeOutShape(inputs.map(shapeOf), $axis);
    var outputIsEmpty = tf.util.sizeFromShape(outShape) === 0;
    if (outputIsEmpty) {
        return backend.makeTensorInfo(outShape, inputs[0].dtype, []);
    }
    // Keep only non-empty tensors (ignore tensors with 0 in their shape).
    var nonEmpty = inputs.filter(function (t) {
        return tf.util.sizeFromShape(t.shape) > 0;
    });
    if (nonEmpty.length === 1) {
        return identity({ inputs: { x: nonEmpty[0] }, backend: backend });
    }
    tf.backend_util.assertParamsConsistent(nonEmpty.map(shapeOf), $axis);
    return concatImpl$1(nonEmpty, $axis, backend);
}
|
|
// Kernel config pairing the tf.Concat kernel name with the WebGL `concat`
// function.
var concatConfig = {
    kernelName: tf.Concat,
    backendName: 'webgl',
    kernelFunc: concat
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2017 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
var Conv2DProgram = /** @class */ (function () {
    /**
     * Unpacked 2D convolution program with inputs 'x' and 'W', plus optional
     * 'bias', 'preluActivationWeights' and 'leakyreluAlpha'.
     *
     * The shader walks the filter window, skips taps that fall outside the
     * input, accumulates input channels four at a time with a vec4 dot
     * product, and handles the remaining 1-3 channels in a branch selected
     * when the shader source is generated. Bias and activation, when
     * requested, are applied to the accumulated value before it is written.
     *
     * @param convInfo convolution description; this constructor reads
     *     outShape, padInfo.top/left, strideHeight/Width,
     *     dilationHeight/Width, filterHeight/Width, inChannels, dataFormat,
     *     inHeight and inWidth.
     * @param addBias whether a 'bias' input is added (default false).
     * @param activation GLSL body of the activation function, or null for
     *     none (default null).
     * @param hasPreluActivationWeights whether a 'preluActivationWeights'
     *     input is added (default false).
     * @param hasLeakyreluAlpha whether a 'leakyreluAlpha' input is added
     *     (default false).
     */
    function Conv2DProgram(convInfo, addBias, activation, hasPreluActivationWeights, hasLeakyreluAlpha) {
        if (typeof addBias === 'undefined') {
            addBias = false;
        }
        if (typeof activation === 'undefined') {
            activation = null;
        }
        if (typeof hasPreluActivationWeights === 'undefined') {
            hasPreluActivationWeights = false;
        }
        if (typeof hasLeakyreluAlpha === 'undefined') {
            hasLeakyreluAlpha = false;
        }
        this.variableNames = ['x', 'W'];
        this.outputShape = convInfo.outShape;
        var padTop = convInfo.padInfo.top;
        var padLeft = convInfo.padInfo.left;
        var strideHeight = convInfo.strideHeight;
        var strideWidth = convInfo.strideWidth;
        var dilationHeight = convInfo.dilationHeight;
        var dilationWidth = convInfo.dilationWidth;
        var filterHeight = convInfo.filterHeight;
        var filterWidth = convInfo.filterWidth;
        // Channels are consumed four at a time; the remainder gets its own branch.
        var inputDepthNearestVec4 = Math.floor(convInfo.inChannels / 4) * 4;
        var inputDepthVec4Remainder = convInfo.inChannels % 4;
        // Positions of row, column and channel inside the 4D output coords.
        var isChannelsLast = convInfo.dataFormat === 'channelsLast';
        var rowDim;
        var colDim;
        var channelDim;
        if (isChannelsLast) {
            rowDim = 1;
            colDim = 2;
            channelDim = 3;
        }
        else {
            rowDim = 2;
            colDim = 3;
            channelDim = 1;
        }
        var activationSnippet = '';
        var applyActivationSnippet = '';
        if (activation) {
            // The two-operand form exposes a per-element weight as `b`.
            var weightAccessor = null;
            if (hasPreluActivationWeights) {
                weightAccessor = 'getPreluActivationWeightsAtOutCoords';
            }
            else if (hasLeakyreluAlpha) {
                weightAccessor = 'getLeakyreluAlphaAtOutCoords';
            }
            if (weightAccessor === null) {
                activationSnippet = "\n float activation(float x) {\n " + activation + "\n }\n ";
            }
            else {
                activationSnippet = "float activation(float a) {\n float b = " + weightAccessor + "();\n " + activation + "\n }";
            }
            applyActivationSnippet = "result = activation(result);";
        }
        var addBiasSnippet = addBias ? 'result += getBiasAtOutCoords();' : '';
        // Optional inputs are appended in a fixed order.
        var optionalInputs = [
            [addBias, 'bias'],
            [hasPreluActivationWeights, 'preluActivationWeights'],
            [hasLeakyreluAlpha, 'leakyreluAlpha']
        ];
        for (var k = 0; k < optionalInputs.length; k++) {
            if (optionalInputs[k][0]) {
                this.variableNames.push(optionalInputs[k][1]);
            }
        }
        // Constants, output coordinates and the top-left input corner.
        var prologue = "\n " + activationSnippet + "\n\n const ivec2 strides = ivec2(" + strideHeight + ", " + strideWidth + ");\n const ivec2 pads = ivec2(" + padTop + ", " + padLeft + ");\n\n void main() {\n ivec4 coords = getOutputCoords();\n int batch = coords[0];\n int d2 = coords[" + channelDim + "];\n\n ivec2 xRCCorner =\n ivec2(coords[" + rowDim + "], coords[" + colDim + "]) * strides - pads;\n int xRCorner = xRCCorner.x;\n int xCCorner = xRCCorner.y;\n\n // Convolve x(?, ?, d1) with w(:, :, d1, d2) to get y(yR, yC, d2).\n // ? = to be determined. : = across all values in that axis.\n float dotProd = 0.0;\n ";
        // Filter-window loops with out-of-bounds taps skipped.
        var windowLoops = "for (int wR = 0; wR < " + filterHeight + "; wR++) {\n int xR = xRCorner + wR * " + dilationHeight + ";\n\n if (xR < 0 || xR >= " + convInfo.inHeight + ") {\n continue;\n }\n\n for (int wC = 0; wC < " + filterWidth + "; wC++) {\n int xC = xCCorner + wC * " + dilationWidth + ";\n\n if (xC < 0 || xC >= " + convInfo.inWidth + ") {\n continue;\n }\n\n ";
        // Four input channels per iteration.
        var vec4Channels = "for (int d1 = 0; d1 < " + inputDepthNearestVec4 + "; d1 += 4) {\n vec4 wValues = vec4(\n getW(wR, wC, d1, d2),\n getW(wR, wC, d1 + 1, d2),\n getW(wR, wC, d1 + 2, d2),\n getW(wR, wC, d1 + 3, d2)\n );\n\n if (" + isChannelsLast + ") {\n vec4 xValues = vec4(\n getX(batch, xR, xC, d1),\n getX(batch, xR, xC, d1 + 1),\n getX(batch, xR, xC, d1 + 2),\n getX(batch, xR, xC, d1 + 3)\n );\n dotProd += dot(xValues, wValues);\n } else {\n vec4 xValues = vec4(\n getX(batch, d1, xR, xC),\n getX(batch, d1 + 1, xR, xC),\n getX(batch, d1 + 2, xR, xC),\n getX(batch, d1 + 3, xR, xC)\n );\n dotProd += dot(xValues, wValues);\n }\n }\n\n ";
        // One leftover channel.
        var remainderOne = "if (" + (inputDepthVec4Remainder === 1) + ") {\n\n if (" + isChannelsLast + ") {\n dotProd +=\n getX(batch, xR, xC, " + inputDepthNearestVec4 + ") *\n getW(wR, wC, " + inputDepthNearestVec4 + ", d2);\n } else {\n dotProd +=\n getX(batch, " + inputDepthNearestVec4 + ", xR, xC) *\n getW(wR, wC, " + inputDepthNearestVec4 + ", d2);\n }\n\n } ";
        // Two leftover channels.
        var remainderTwo = "else if (" + (inputDepthVec4Remainder === 2) + ") {\n vec2 wValues = vec2(\n getW(wR, wC, " + inputDepthNearestVec4 + ", d2),\n getW(wR, wC, " + inputDepthNearestVec4 + " + 1, d2)\n );\n\n if (" + isChannelsLast + ") {\n vec2 xValues = vec2(\n getX(batch, xR, xC, " + inputDepthNearestVec4 + "),\n getX(batch, xR, xC, " + inputDepthNearestVec4 + " + 1)\n );\n dotProd += dot(xValues, wValues);\n } else {\n vec2 xValues = vec2(\n getX(batch, " + inputDepthNearestVec4 + ", xR, xC),\n getX(batch, " + inputDepthNearestVec4 + " + 1, xR, xC)\n );\n dotProd += dot(xValues, wValues);\n }\n\n } ";
        // Three leftover channels, then close both window loops.
        var remainderThree = "else if (" + (inputDepthVec4Remainder === 3) + ") {\n vec3 wValues = vec3(\n getW(wR, wC, " + inputDepthNearestVec4 + ", d2),\n getW(wR, wC, " + inputDepthNearestVec4 + " + 1, d2),\n getW(wR, wC, " + inputDepthNearestVec4 + " + 2, d2)\n );\n\n if (" + isChannelsLast + ") {\n vec3 xValues = vec3(\n getX(batch, xR, xC, " + inputDepthNearestVec4 + "),\n getX(batch, xR, xC, " + inputDepthNearestVec4 + " + 1),\n getX(batch, xR, xC, " + inputDepthNearestVec4 + " + 2)\n );\n dotProd += dot(xValues, wValues);\n } else {\n vec3 xValues = vec3(\n getX(batch, " + inputDepthNearestVec4 + ", xR, xC),\n getX(batch, " + inputDepthNearestVec4 + " + 1, xR, xC),\n getX(batch, " + inputDepthNearestVec4 + " + 2, xR, xC)\n );\n dotProd += dot(xValues, wValues);\n }\n\n }\n }\n }\n\n ";
        // Bias, activation and output.
        var epilogue = "float result = dotProd;\n " + addBiasSnippet + "\n " + applyActivationSnippet + "\n setOutput(result);\n }\n ";
        this.userCode = prologue + windowLoops + vec4Channels +
            remainderOne + remainderTwo + remainderThree + epilogue;
    }
    return Conv2DProgram;
}());
|
|
var Conv3DProgram = /** @class */ (function () {
    /**
     * Describes the shader for a forward 3D convolution.
     *
     * The generated GLSL reads two inputs, `x` and `W`, and for every output
     * coordinate accumulates a dot product over the filter window, skipping
     * taps that fall outside the input bounds. Input channels are consumed
     * four at a time, with a scalar/vec2/vec3 tail for the remainder.
     *
     * @param convInfo Convolution description; the constructor reads its
     *     `outShape`, `padInfo`, stride, dilation, filter and input-extent
     *     fields as well as `inChannels`.
     */
    function Conv3DProgram(convInfo) {
        this.variableNames = ['x', 'W'];
        this.outputShape = convInfo.outShape;
        var pads = convInfo.padInfo;
        var strideD = convInfo.strideDepth;
        var strideH = convInfo.strideHeight;
        var strideW = convInfo.strideWidth;
        var dilationD = convInfo.dilationDepth;
        var dilationH = convInfo.dilationHeight;
        var dilationW = convInfo.dilationWidth;
        // Channels handled by the vec4 loop, and how many are left over.
        var vec4Channels = Math.floor(convInfo.inChannels / 4) * 4;
        var tailChannels = convInfo.inChannels % 4;
        var shaderLines = [
            '',
            ' const ivec3 strides = ivec3(' + strideD + ', ' + strideH + ', ' + strideW + ');',
            ' const ivec3 pads = ivec3(' + pads.front + ', ' + pads.top + ', ' + pads.left + ');',
            '',
            ' void main() {',
            ' ivec5 coords = getOutputCoords();',
            ' int batch = coords.x;',
            ' int d2 = coords.u;',
            '',
            ' ivec3 xFRCCorner = ivec3(coords.y, coords.z, coords.w) * strides - pads;',
            ' int xFCorner = xFRCCorner.x;',
            ' int xRCorner = xFRCCorner.y;',
            ' int xCCorner = xFRCCorner.z;',
            '',
            ' // Convolve x(?, ?, ?, d1) with w(:, :, :, d1, d2) to get',
            ' // y(yF, yR, yC, d2). ? = to be determined. : = across all',
            ' // values in that axis.',
            ' float dotProd = 0.0;',
            ' for (int wF = 0; wF < ' + convInfo.filterDepth + '; wF++) {',
            ' int xF = xFCorner + wF * ' + dilationD + ';',
            '',
            ' if (xF < 0 || xF >= ' + convInfo.inDepth + ') {',
            ' continue;',
            ' }',
            '',
            ' for (int wR = 0; wR < ' + convInfo.filterHeight + '; wR++) {',
            ' int xR = xRCorner + wR * ' + dilationH + ';',
            '',
            ' if (xR < 0 || xR >= ' + convInfo.inHeight + ') {',
            ' continue;',
            ' }',
            '',
            ' for (int wC = 0; wC < ' + convInfo.filterWidth + '; wC++) {',
            ' int xC = xCCorner + wC * ' + dilationW + ';',
            '',
            ' if (xC < 0 || xC >= ' + convInfo.inWidth + ') {',
            ' continue;',
            ' }',
            '',
            ' for (int d1 = 0; d1 < ' + vec4Channels + '; d1 += 4) {',
            ' vec4 xValues = vec4(',
            ' getX(batch, xF, xR, xC, d1),',
            ' getX(batch, xF, xR, xC, d1 + 1),',
            ' getX(batch, xF, xR, xC, d1 + 2),',
            ' getX(batch, xF, xR, xC, d1 + 3)',
            ' );',
            ' vec4 wValues = vec4(',
            ' getW(wF, wR, wC, d1, d2),',
            ' getW(wF, wR, wC, d1 + 1, d2),',
            ' getW(wF, wR, wC, d1 + 2, d2),',
            ' getW(wF, wR, wC, d1 + 3, d2)',
            ' );',
            '',
            ' dotProd += dot(xValues, wValues);',
            ' }',
            '',
            ' if (' + (tailChannels === 1) + ') {',
            ' dotProd +=',
            ' getX(batch, xF, xR, xC, ' + vec4Channels + ') *',
            ' getW(wF, wR, wC, ' + vec4Channels + ', d2);',
            ' } else if (' + (tailChannels === 2) + ') {',
            ' vec2 xValues = vec2(',
            ' getX(batch, xF, xR, xC, ' + vec4Channels + '),',
            ' getX(batch, xF, xR, xC, ' + vec4Channels + ' + 1)',
            ' );',
            ' vec2 wValues = vec2(',
            ' getW(wF, wR, wC, ' + vec4Channels + ', d2),',
            ' getW(wF, wR, wC, ' + vec4Channels + ' + 1, d2)',
            ' );',
            ' dotProd += dot(xValues, wValues);',
            ' } else if (' + (tailChannels === 3) + ') {',
            ' vec3 xValues = vec3(',
            ' getX(batch, xF, xR, xC, ' + vec4Channels + '),',
            ' getX(batch, xF, xR, xC, ' + vec4Channels + ' + 1),',
            ' getX(batch, xF, xR, xC, ' + vec4Channels + ' + 2)',
            ' );',
            ' vec3 wValues = vec3(',
            ' getW(wF, wR, wC, ' + vec4Channels + ', d2),',
            ' getW(wF, wR, wC, ' + vec4Channels + ' + 1, d2),',
            ' getW(wF, wR, wC, ' + vec4Channels + ' + 2, d2)',
            ' );',
            ' dotProd += dot(xValues, wValues);',
            ' }',
            ' }',
            ' }',
            ' }',
            ' setOutput(dotProd);',
            ' }',
            ' '
        ];
        this.userCode = shaderLines.join('\n');
    }
    return Conv3DProgram;
}());
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
var Im2ColPackedProgram = /** @class */ (function () {
    /**
     * Describes the packed shader that rearranges an image into im2col form.
     *
     * Each invocation fills a four-lane `result` vector; the GLSL for the
     * four lanes is unrolled on the JavaScript side, one snippet per
     * (row, col) offset of the packed output coordinate.
     *
     * @param outputShape Two-element shape of the matrix being produced.
     * @param inputShape Shape of the input, indexed by the row/col dimension
     *     chosen from `convInfo.dataFormat`.
     * @param convInfo Convolution description (filter width, channels,
     *     strides, padding, dilations, output width and data format).
     */
    function Im2ColPackedProgram(outputShape, inputShape, convInfo) {
        this.variableNames = ['A'];
        this.packedInputs = true;
        this.packedOutput = true;
        this.outputShape = outputShape;
        var inChannels = convInfo.inChannels;
        var outWidth = convInfo.outWidth;
        var strideWidth = convInfo.strideWidth;
        var strideHeight = convInfo.strideHeight;
        var dilationWidth = convInfo.dilationWidth;
        var dilationHeight = convInfo.dilationHeight;
        var padLeft = convInfo.padInfo.left;
        var padTop = convInfo.padInfo.top;
        var itemsPerBlockRow = inChannels * convInfo.filterWidth;
        var glsl = getGlslDifferences();
        var isChannelsLast = convInfo.dataFormat === 'channelsLast';
        var rowDim = isChannelsLast ? 0 : 1;
        var colDim = isChannelsLast ? 1 : 2;
        var unrolled = '';
        // lane == row * 2 + col, visited in the order (0,0), (0,1), (1,0), (1,1).
        for (var lane = 0; lane < 4; lane++) {
            var row = lane >> 1;
            var col = lane & 1;
            unrolled += [
                '',
                ' blockIndex = rc.y + ' + col + ';',
                ' pos = rc.x + ' + row + ';',
                '',
                ' if(blockIndex < ' + outputShape[1] + ' && pos < ' + outputShape[0] + ') {',
                ' offsetY = int(blockIndex / (' + outWidth + ')) * ' + strideHeight + ' - ' + padTop + ';',
                ' d0 = offsetY + ' + dilationHeight + ' * (pos / ' + itemsPerBlockRow + ');',
                '',
                ' if(d0 < ' + inputShape[rowDim] + ' && d0 >= 0) {',
                '',
                ' offsetX = int(mod(float(blockIndex), ' + outWidth + '.) * ' + strideWidth + '. - ' + padLeft + '.);',
                ' d1 = offsetX + ' + dilationWidth + ' * (int(mod(float(pos), ' + itemsPerBlockRow + '.) / ' + inChannels + '.));',
                '',
                ' if(d1 < ' + inputShape[colDim] + ' && d1 >= 0) {',
                '',
                ' ch = int(mod(float(pos), ' + inChannels + '.));',
                '',
                ' if (' + isChannelsLast + ') {',
                ' innerDims = vec2(d1, ch);',
                ' result[' + lane + '] = getChannel(',
                ' getA(d0, int(innerDims.x),',
                ' int(innerDims.y)), innerDims);',
                ' } else {',
                ' innerDims = vec2(d0, d1);',
                ' result[' + lane + '] = getChannel(',
                ' getA(ch, int(innerDims.x),',
                ' int(innerDims.y)), innerDims);',
                ' }',
                ' }',
                ' }',
                ' }',
                ' '
            ].join('\n');
        }
        this.userCode = [
            '',
            ' void main() {',
            ' ivec2 rc = getOutputCoords();',
            '',
            ' vec4 result = vec4(0);',
            '',
            ' int blockIndex, pos, offsetY, d0, offsetX, d1, ch;',
            ' vec2 innerDims;',
            '',
            ' ' + unrolled,
            '',
            ' ' + glsl.output + ' = result;',
            ' }',
            ' '
        ].join('\n');
    }
    return Im2ColPackedProgram;
}());
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// For 1x1 kernels that iterate through every point in the input, convolution
|
|
// can be expressed as matrix multiplication (without need for memory
|
|
// remapping).
|
|
/**
 * Computes a 2D convolution by reshaping the operands and delegating to
 * `batchMatMulImpl`.
 *
 * Two strategies are used:
 *  - the general one reshapes `x` and `filter` into rank-3 tensors, multiplies
 *    them and reshapes the product to `convInfo.outShape`;
 *  - when `x` is packed and `x.shape[2]` is odd (and the relevant env flags
 *    are enabled) the reshape of `x` is avoided by temporarily presenting its
 *    texture data with the second-to-last dimension incremented by one.
 *
 * The temporary change to `xTexData.shape` is now undone in a `finally`
 * block, so a failure inside the matmul path can no longer leave the input
 * tensor's texture metadata padded.
 *
 * NOTE(review): `outerShapeX` and the odd-row test index `xShape` as if the
 * layout were channels-last; confirm the intended behavior for
 * 'channelsFirst' inputs.
 *
 * @param _a Options bag: `x`, `filter`, `convInfo`, `backend`, and optional
 *     `bias` (default null), `preluActivationWeights` (default null),
 *     `leakyreluAlpha` (default 0) and `activation` (default null).
 * @returns The tensor info holding the convolution result.
 */
function conv2dByMatMul(_a) {
    var x = _a.x, filter = _a.filter, convInfo = _a.convInfo, backend = _a.backend, _b = _a.bias, bias = _b === void 0 ? null : _b, _c = _a.preluActivationWeights, preluActivationWeights = _c === void 0 ? null : _c, _d = _a.leakyreluAlpha, leakyreluAlpha = _d === void 0 ? 0 : _d, _e = _a.activation, activation = _e === void 0 ? null : _e;
    // Reshapes conv2D input to 2D tensors, uses matMul and then reshape the
    // result from 2D to 4D.
    var xShape = x.shape;
    var xTexData = backend.texData.get(x.dataId);
    var sharedMatMulDim = convInfo.inChannels;
    var outerShapeX = xShape[0] * xShape[1] * xShape[2];
    var outerShapeFilter = convInfo.outChannels;
    var isChannelsLast = convInfo.dataFormat === 'channelsLast';
    var transposeA = false;
    var transposeB = false;
    var out;
    var intermediates = [];
    var targetShape;
    var filterReshaped;
    // TODO: Once reduction ops are packed, batchMatMul will always be packed
    // and we can remove this condition.
    var batchMatMulWillBeUnpacked = (outerShapeX === 1 || outerShapeFilter === 1) &&
        sharedMatMulDim > MATMUL_SHARED_DIM_THRESHOLD;
    var reshapeWillBeExpensive = xShape[2] % 2 !== 0 && !!xTexData.isPacked;
    if (batchMatMulWillBeUnpacked || !tf.env().getBool('WEBGL_LAZILY_UNPACK') ||
        !tf.env().getBool('WEBGL_PACK_BINARY_OPERATIONS') ||
        !reshapeWillBeExpensive) {
        targetShape = isChannelsLast ? xShape[0] * xShape[1] * xShape[2] :
            xShape[0] * xShape[2] * xShape[3];
        var xReshaped = reshape({
            inputs: { x: x },
            backend: backend,
            attrs: { shape: [1, targetShape, convInfo.inChannels] }
        });
        filterReshaped = reshape({
            inputs: { x: filter },
            backend: backend,
            attrs: { shape: [1, convInfo.inChannels, convInfo.outChannels] }
        });
        var result = batchMatMulImpl({
            a: xReshaped,
            b: filterReshaped,
            transposeA: transposeA,
            transposeB: transposeB,
            backend: backend,
            bias: bias,
            activation: activation,
            preluActivationWeights: preluActivationWeights,
            leakyreluAlpha: leakyreluAlpha
        });
        out = reshape({ inputs: { x: result }, backend: backend, attrs: { shape: convInfo.outShape } });
        intermediates.push(xReshaped);
        intermediates.push(filterReshaped);
        intermediates.push(result);
    }
    else {
        // Following optimization is specific to packed |x| with odd row count
        // (For example, in channelLast mode, 'row count' refers to x.shape[2]):
        // we avoid expensive packed 2x2 reshape by padding row count to next,
        // even number. When x.shape[2] is odd, the result of packed batchMatMul is
        // the same (has the same texture layout and values in the texture) as
        // it is for even x.shape[2] + 1. We make the odd-rows tensor to look like
        // even-rows tensor before the operation and, after the batchMatMul,
        // fix the even-rows result to have odd number of rows.
        targetShape = isChannelsLast ?
            xShape[0] * xShape[1] * (xShape[2] + 1) :
            xShape[0] * xShape[2] * (xShape[3] + 1);
        var xReshaped_1 = {
            dataId: x.dataId,
            shape: [1, targetShape, convInfo.inChannels],
            dtype: x.dtype
        };
        // xTexData.shape gets referenced from GPGPUBinary.inShapeInfos.
        // Decrementing row count, after batchMatMul->...->compileProgram leads to
        // invalid row count within the reference in GPGPUBinary.inShapeInfos.
        // Alternative fix would be to provide a copy to GPGPUBinary.inShapeInfos
        // in compileProgram method, but that would affect compilation of all
        // programs - instead, provide a copy here, with even row count, before
        // calling batchMatMul->...->compileProgram and after that, the original
        // xTexData.shape is restored.
        var originalXTexDataShape = xTexData.shape;
        var pointwiseConv;
        var pointwiseConvTexData;
        xTexData.shape = xTexData.shape.slice();
        xTexData.shape[xTexData.shape.length - 2]++;
        try {
            tf.util.assert(isReshapeFree(xTexData.shape, xReshaped_1.shape), function () { return "packed reshape " + xTexData.shape + " to " + xReshaped_1.shape + " isn't free"; });
            filterReshaped = reshape({
                inputs: { x: filter },
                backend: backend,
                attrs: { shape: [1, convInfo.inChannels, convInfo.outChannels] }
            });
            intermediates.push(filterReshaped);
            pointwiseConv = batchMatMulImpl({
                a: xReshaped_1,
                b: filterReshaped,
                backend: backend,
                transposeA: transposeA,
                transposeB: transposeB,
                bias: bias,
                activation: activation,
                preluActivationWeights: preluActivationWeights,
                leakyreluAlpha: leakyreluAlpha
            });
            pointwiseConvTexData = backend.texData.get(pointwiseConv.dataId);
            tf.util.assert(pointwiseConvTexData.isPacked, function () { return 'batchMatMul result is expected to be packed'; });
        }
        finally {
            // Restore the input shape to original, even when one of the steps
            // above throws; otherwise x would keep the padded row count.
            xTexData.shape = originalXTexDataShape;
        }
        // Set the output shape - there is no need for expensive reshape as data
        // layout is already correct.
        pointwiseConvTexData.shape = convInfo.outShape;
        out = identity({ inputs: { x: pointwiseConv }, backend: backend });
        out.shape = convInfo.outShape;
        intermediates.push(pointwiseConv);
    }
    for (var _i = 0, intermediates_1 = intermediates; _i < intermediates_1.length; _i++) {
        var i = intermediates_1[_i];
        backend.disposeIntermediateTensorInfo(i);
    }
    return out;
}
|
|
// Implements the im2row algorithm as outlined in "High Performance
|
|
// Convolutional Neural Networks for Document Processing" (Suvisoft, 2006)
|
|
/**
 * Computes a 2D convolution as im2col followed by a packed matrix multiply.
 *
 * The input (with its leading dimension dropped) is rearranged by
 * `Im2ColPackedProgram` into a [sharedDim, outHeight * outWidth] matrix, the
 * filter is reshaped to [1, sharedDim, rest], and `MatMulPackedProgram`
 * multiplies them (with the first operand transposed). Optional bias, PReLU
 * weights and a leaky-ReLU alpha scalar are appended as extra program inputs.
 * Every temporary created along the way is disposed before returning.
 *
 * @param _a Options bag: `x`, `filter`, `convInfo`, `backend`, and optional
 *     `bias` (default null), `preluActivationWeights` (default null),
 *     `leakyreluAlpha` (default 0) and `activation` (default null).
 * @returns The product reshaped to [1, H, W, C] for 'channelsLast', otherwise
 *     [1, C, H, W].
 */
function conv2dWithIm2Row(_a) {
    var x = _a.x;
    var filter = _a.filter;
    var convInfo = _a.convInfo;
    var backend = _a.backend;
    var bias = _a.bias === void 0 ? null : _a.bias;
    var preluActivationWeights = _a.preluActivationWeights === void 0 ? null : _a.preluActivationWeights;
    var leakyreluAlpha = _a.leakyreluAlpha === void 0 ? 0 : _a.leakyreluAlpha;
    var activation = _a.activation === void 0 ? null : _a.activation;
    var outHeight = convInfo.outHeight;
    var outWidth = convInfo.outWidth;
    var channelsLast = convInfo.dataFormat === 'channelsLast';
    // Length of one flattened filter window, shared by both matmul operands.
    var sharedDim = convInfo.filterWidth * convInfo.filterHeight * convInfo.inChannels;
    var numCols = outHeight * outWidth;
    var x2ColShape = [sharedDim, numCols];
    // Temporaries, disposed at the end in creation order.
    var toDispose = [];
    var xSqueezed = reshape({ inputs: { x: x }, backend: backend, attrs: { shape: x.shape.slice(1) } });
    var w2Row = reshape({
        inputs: { x: filter },
        backend: backend,
        attrs: { shape: [1, sharedDim, tf.util.sizeFromShape(filter.shape) / sharedDim] }
    });
    toDispose.push(xSqueezed, w2Row);
    var im2Col = backend.runWebGLProgram(new Im2ColPackedProgram(x2ColShape, xSqueezed.shape, convInfo), [xSqueezed], 'float32');
    var im2ColReshaped = reshape({
        inputs: { x: im2Col },
        backend: backend,
        attrs: { shape: [1, x2ColShape[0], x2ColShape[1]] }
    });
    toDispose.push(im2Col, im2ColReshaped);
    var withBias = bias != null;
    var withPrelu = preluActivationWeights != null;
    var withLeakyrelu = activation === 'leakyrelu';
    var fusedActivation = null;
    if (activation) {
        fusedActivation = mapActivationToShaderProgram(activation, true);
    }
    var matmulProgram = new MatMulPackedProgram(im2ColReshaped.shape, w2Row.shape, [1, numCols, convInfo.outChannels], 
    /* transposeA */ true, /* transposeB */ false, withBias, fusedActivation, withPrelu, withLeakyrelu);
    var programInputs = [im2ColReshaped, w2Row];
    if (bias) {
        programInputs.push(bias);
    }
    if (withPrelu) {
        programInputs.push(preluActivationWeights);
    }
    if (withLeakyrelu) {
        var alphaScalar = backend.makeTensorInfo([], 'float32', tf.util.createScalarValue(leakyreluAlpha, 'float32'));
        programInputs.push(alphaScalar);
        toDispose.push(alphaScalar);
    }
    var product = backend.runWebGLProgram(matmulProgram, programInputs, 'float32');
    var outShape;
    if (channelsLast) {
        outShape = [1, outHeight, outWidth, convInfo.outChannels];
    }
    else {
        outShape = [1, convInfo.outChannels, outHeight, outWidth];
    }
    var out = reshape({ inputs: { x: product }, backend: backend, attrs: { shape: outShape } });
    toDispose.push(product);
    for (var idx = 0; idx < toDispose.length; idx++) {
        backend.disposeIntermediateTensorInfo(toDispose[idx]);
    }
    return out;
}
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * WebGL kernel function for Conv2D.
 *
 * Picks one of three implementations:
 *  - `conv2dByMatMul` for 1x1 filters with unit stride and dilation and
 *    'SAME' or 'VALID' padding;
 *  - `conv2dWithIm2Row` when the `WEBGL_CONV_IM2COL` flag is set and the
 *    leading dimension of `x` is 1;
 *  - `Conv2DProgram` otherwise.
 * The chosen result is reshaped to `convInfo.outShape` and the pre-reshape
 * tensor is disposed.
 *
 * @param args Kernel arguments with `inputs` ({x, filter}), `backend` and
 *     `attrs` ({strides, pad, dataFormat, dilations, dimRoundingMode}).
 * @returns The reshaped output tensor info.
 */
function conv2d(args) {
    var inputs = args.inputs;
    var backend = args.backend;
    var attrs = args.attrs;
    var x = inputs.x;
    var filter = inputs.filter;
    var layout = tf.backend_util.convertConv2DDataFormat(attrs.dataFormat);
    var convInfo = tf.backend_util.computeConv2DInfo(x.shape, filter.shape, attrs.strides, attrs.dilations, attrs.pad, attrs.dimRoundingMode, false /* depthwise */, layout);
    var isPlainPointwise = convInfo.filterHeight === 1 && convInfo.filterWidth === 1 &&
        convInfo.dilationHeight === 1 && convInfo.dilationWidth === 1 &&
        convInfo.strideHeight === 1 && convInfo.strideWidth === 1 &&
        (convInfo.padInfo.type === 'SAME' || convInfo.padInfo.type === 'VALID');
    var raw;
    if (isPlainPointwise) {
        raw = conv2dByMatMul({ x: x, filter: filter, convInfo: convInfo, backend: backend });
    }
    else if (tf.env().getBool('WEBGL_CONV_IM2COL') && x.shape[0] === 1) {
        raw = conv2dWithIm2Row({ x: x, filter: filter, convInfo: convInfo, backend: backend });
    }
    else {
        raw = backend.runWebGLProgram(new Conv2DProgram(convInfo), [x, filter], 'float32');
    }
    var reshaped = reshape({ inputs: { x: raw }, backend: backend, attrs: { shape: convInfo.outShape } });
    backend.disposeIntermediateTensorInfo(raw);
    return reshaped;
}
|
|
/** Kernel config pairing the `tf.Conv2D` kernel name with the WebGL `conv2d` implementation. */
var conv2DConfig = { kernelName: tf.Conv2D, backendName: 'webgl', kernelFunc: conv2d };
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2017 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
var Conv2DDerFilterProgram = /** @class */ (function () {
    /**
     * Describes the shader computing the gradient of a 2D convolution with
     * respect to its filter.
     *
     * For each filter coordinate (wR, wC, d1, d2) the GLSL sums
     * `x * dy` over the batch and every output position whose corresponding
     * input position lies inside the input bounds. The data format selects
     * the argument order used when sampling `x` and `dy`.
     *
     * @param convInfo Convolution description; `filterShape` becomes the
     *     program's output shape.
     */
    function Conv2DDerFilterProgram(convInfo) {
        this.variableNames = ['x', 'dy'];
        this.outputShape = convInfo.filterShape;
        var strideH = convInfo.strideHeight;
        var strideW = convInfo.strideWidth;
        var topPad = convInfo.padInfo.top;
        var leftPad = convInfo.padInfo.left;
        var channelsLast = convInfo.dataFormat === 'channelsLast';
        var shaderLines = [
            '',
            ' void main() {',
            ' ivec4 coords = getOutputCoords();',
            ' int wR = coords.x;',
            ' int wC = coords.y;',
            ' int d1 = coords.z;',
            ' int d2 = coords.w;',
            '',
            ' // Convolve x(?, ?, d1) with dy(:, :, d2) to get dw(wR, wC, d1, d2).',
            ' // ? = to be determined. : = across all values in that axis.',
            ' float dotProd = 0.0;',
            '',
            ' for (int b = 0; b < ' + convInfo.batchSize + '; b++) {',
            ' for (int yR = 0; yR < ' + convInfo.outHeight + '; yR++) {',
            ' int xR = wR + yR * ' + strideH + ' - ' + topPad + ';',
            '',
            ' if (xR < 0 || xR >= ' + convInfo.inHeight + ') {',
            ' continue;',
            ' }',
            '',
            ' for (int yC = 0; yC < ' + convInfo.outWidth + '; yC++) {',
            ' int xC = wC + yC * ' + strideW + ' - ' + leftPad + ';',
            '',
            ' if (xC < 0 || xC >= ' + convInfo.inWidth + ') {',
            ' continue;',
            ' }',
            '',
            ' if (' + channelsLast + ') {',
            ' float dyValue = getDy(b, yR, yC, d2);',
            ' float xValue = getX(b, xR, xC, d1);',
            ' dotProd += (xValue * dyValue);',
            ' } else {',
            ' float dyValue = getDy(b, d2, yR, yC);',
            ' float xValue = getX(b, d1, xR, xC);',
            ' dotProd += (xValue * dyValue);',
            ' }',
            '',
            ' }',
            ' }',
            ' }',
            ' setOutput(dotProd);',
            ' }',
            ' '
        ];
        this.userCode = shaderLines.join('\n');
    }
    return Conv2DDerFilterProgram;
}());
|
|
var Conv2DDerInputProgram = /** @class */ (function () {
    /**
     * Describes the shader computing the gradient of a 2D convolution with
     * respect to its input.
     *
     * For each input coordinate the GLSL walks the filter window, maps every
     * tap back to a `dy` position by dividing by the stride, skips positions
     * that are out of range or fractional, and accumulates `dy * W` over all
     * output channels using the spatially flipped filter indices.
     *
     * @param convInfo Convolution description; `inShape` becomes the
     *     program's output shape.
     */
    function Conv2DDerInputProgram(convInfo) {
        this.variableNames = ['dy', 'W'];
        this.outputShape = convInfo.inShape;
        var fH = convInfo.filterHeight;
        var fW = convInfo.filterWidth;
        var strideH = convInfo.strideHeight;
        var strideW = convInfo.strideWidth;
        var channelsLast = convInfo.dataFormat === 'channelsLast';
        var topPad = fH - 1 - convInfo.padInfo.top;
        var leftPad = fW - 1 - convInfo.padInfo.left;
        // Positions of the row / column / channel axes in the output coords.
        var rowAxis = 2;
        var colAxis = 3;
        var channelAxis = 1;
        if (channelsLast) {
            rowAxis = 1;
            colAxis = 2;
            channelAxis = 3;
        }
        var shaderLines = [
            '',
            ' const ivec2 pads = ivec2(' + topPad + ', ' + leftPad + ');',
            '',
            ' void main() {',
            ' ivec4 coords = getOutputCoords();',
            ' int batch = coords[0];',
            ' int d1 = coords[' + channelAxis + '];',
            '',
            ' ivec2 dyCorner = ivec2(coords[' + rowAxis + '], coords[' + colAxis + ']) - pads;',
            ' int dyRCorner = dyCorner.x;',
            ' int dyCCorner = dyCorner.y;',
            '',
            ' // Convolve dy(?, ?, d2) with w(:, :, d1, d2) to compute dx(xR, xC, d1).',
            ' // ? = to be determined. : = across all values in that axis.',
            ' float dotProd = 0.0;',
            ' for (int wR = 0; wR < ' + fH + '; wR++) {',
            ' float dyR = float(dyRCorner + wR) / ' + strideH + '.0;',
            '',
            ' if (dyR < 0.0 || dyR >= ' + convInfo.outHeight + '.0 || fract(dyR) > 0.0) {',
            ' continue;',
            ' }',
            ' int idyR = int(dyR);',
            '',
            ' int wRPerm = ' + fH + ' - 1 - wR;',
            '',
            ' for (int wC = 0; wC < ' + fW + '; wC++) {',
            ' float dyC = float(dyCCorner + wC) / ' + strideW + '.0;',
            '',
            ' if (dyC < 0.0 || dyC >= ' + convInfo.outWidth + '.0 ||',
            ' fract(dyC) > 0.0) {',
            ' continue;',
            ' }',
            ' int idyC = int(dyC);',
            '',
            ' int wCPerm = ' + fW + ' - 1 - wC;',
            '',
            ' for (int d2 = 0; d2 < ' + convInfo.outChannels + '; d2++) {',
            '',
            ' if (' + channelsLast + ') {',
            ' float xValue = getDy(batch, idyR, idyC, d2);',
            ' float wValue = getW(wRPerm, wCPerm, d1, d2);',
            ' dotProd += xValue * wValue;',
            ' } else {',
            ' float xValue = getDy(batch, d2, idyR, idyC);',
            ' float wValue = getW(wRPerm, wCPerm, d1, d2);',
            ' dotProd += xValue * wValue;',
            ' }',
            '',
            ' }',
            ' }',
            ' }',
            ' setOutput(dotProd);',
            ' }',
            ' '
        ];
        this.userCode = shaderLines.join('\n');
    }
    return Conv2DDerInputProgram;
}());
|
|
var Conv3DDerFilterProgram = /** @class */ (function () {
    /**
     * Describes the shader computing the gradient of a 3D convolution with
     * respect to its filter.
     *
     * For each filter coordinate (wF, wR, wC, d1, d2) the GLSL sums
     * `x * dy` over the batch and every output position whose corresponding
     * input position lies inside the input bounds.
     *
     * @param convInfo Convolution description; `filterShape` becomes the
     *     program's output shape.
     */
    function Conv3DDerFilterProgram(convInfo) {
        this.variableNames = ['x', 'dy'];
        this.outputShape = convInfo.filterShape;
        var strideD = convInfo.strideDepth;
        var strideH = convInfo.strideHeight;
        var strideW = convInfo.strideWidth;
        var pads = convInfo.padInfo;
        var frontPad = pads.front;
        var topPad = pads.top;
        var leftPad = pads.left;
        var shaderLines = [
            '',
            ' void main() {',
            ' ivec5 coords = getOutputCoords();',
            ' int wF = coords.x;',
            ' int wR = coords.y;',
            ' int wC = coords.z;',
            ' int d1 = coords.w;',
            ' int d2 = coords.u;',
            '',
            ' float dotProd = 0.0;',
            '',
            ' for (int b = 0; b < ' + convInfo.batchSize + '; b++) {',
            ' for (int yF = 0; yF < ' + convInfo.outDepth + '; yF++) {',
            ' int xF = wF + yF * ' + strideD + ' - ' + frontPad + ';',
            '',
            ' if (xF < 0 || xF >= ' + convInfo.inDepth + ') {',
            ' continue;',
            ' }',
            '',
            ' for (int yR = 0; yR < ' + convInfo.outHeight + '; yR++) {',
            ' int xR = wR + yR * ' + strideH + ' - ' + topPad + ';',
            '',
            ' if (xR < 0 || xR >= ' + convInfo.inHeight + ') {',
            ' continue;',
            ' }',
            '',
            ' for (int yC = 0; yC < ' + convInfo.outWidth + '; yC++) {',
            ' int xC = wC + yC * ' + strideW + ' - ' + leftPad + ';',
            '',
            ' if (xC < 0 || xC >= ' + convInfo.inWidth + ') {',
            ' continue;',
            ' }',
            '',
            ' float dyValue = getDy(b, yF, yR, yC, d2);',
            ' float xValue = getX(b, xF, xR, xC, d1);',
            ' dotProd += (xValue * dyValue);',
            ' }',
            ' }',
            ' }',
            ' }',
            ' setOutput(dotProd);',
            ' }',
            ' '
        ];
        this.userCode = shaderLines.join('\n');
    }
    return Conv3DDerFilterProgram;
}());
|
|
var Conv3DDerInputProgram = /** @class */ (function () {
    /**
     * Describes the shader computing the gradient of a 3D convolution with
     * respect to its input.
     *
     * For each input coordinate the GLSL walks the filter window along depth,
     * height and width, maps every tap back to a `dy` position by dividing by
     * the stride, skips out-of-range or fractional positions, and accumulates
     * `dy * W` over all output channels using the flipped filter indices.
     *
     * @param convInfo Convolution description; `inShape` becomes the
     *     program's output shape.
     */
    function Conv3DDerInputProgram(convInfo) {
        this.variableNames = ['dy', 'W'];
        this.outputShape = convInfo.inShape;
        var fD = convInfo.filterDepth;
        var fH = convInfo.filterHeight;
        var fW = convInfo.filterWidth;
        var strideD = convInfo.strideDepth;
        var strideH = convInfo.strideHeight;
        var strideW = convInfo.strideWidth;
        var frontPad = fD - 1 - convInfo.padInfo.front;
        var topPad = fH - 1 - convInfo.padInfo.top;
        var leftPad = fW - 1 - convInfo.padInfo.left;
        var shaderLines = [
            '',
            ' const ivec3 pads = ivec3(' + frontPad + ', ' + topPad + ', ' + leftPad + ');',
            '',
            ' void main() {',
            ' ivec5 coords = getOutputCoords();',
            ' int batch = coords.x;',
            ' int d1 = coords.u;',
            '',
            '',
            ' ivec3 dyCorner = ivec3(coords.y, coords.z, coords.w) - pads;',
            ' int dyFCorner = dyCorner.x;',
            ' int dyRCorner = dyCorner.y;',
            ' int dyCCorner = dyCorner.z;',
            '',
            ' float dotProd = 0.0;',
            ' for (int wF = 0; wF < ' + fD + '; wF++) {',
            ' float dyF = float(dyFCorner + wF) / ' + strideD + '.0;',
            '',
            ' if (dyF < 0.0 || dyF >= ' + convInfo.outDepth + '.0 || fract(dyF) > 0.0) {',
            ' continue;',
            ' }',
            ' int idyF = int(dyF);',
            '',
            ' int wFPerm = ' + fD + ' - 1 - wF;',
            '',
            ' for (int wR = 0; wR < ' + fH + '; wR++) {',
            ' float dyR = float(dyRCorner + wR) / ' + strideH + '.0;',
            '',
            ' if (dyR < 0.0 || dyR >= ' + convInfo.outHeight + '.0 ||',
            ' fract(dyR) > 0.0) {',
            ' continue;',
            ' }',
            ' int idyR = int(dyR);',
            '',
            ' int wRPerm = ' + fH + ' - 1 - wR;',
            '',
            ' for (int wC = 0; wC < ' + fW + '; wC++) {',
            ' float dyC = float(dyCCorner + wC) / ' + strideW + '.0;',
            '',
            ' if (dyC < 0.0 || dyC >= ' + convInfo.outWidth + '.0 ||',
            ' fract(dyC) > 0.0) {',
            ' continue;',
            ' }',
            ' int idyC = int(dyC);',
            '',
            ' int wCPerm = ' + fW + ' - 1 - wC;',
            '',
            ' for (int d2 = 0; d2 < ' + convInfo.outChannels + '; d2++) {',
            ' float xValue = getDy(batch, idyF, idyR, idyC, d2);',
            ' float wValue = getW(wFPerm, wRPerm, wCPerm, d1, d2);',
            ' dotProd += xValue * wValue;',
            ' }',
            ' }',
            ' }',
            ' }',
            ' setOutput(dotProd);',
            ' }',
            ' '
        ];
        this.userCode = shaderLines.join('\n');
    }
    return Conv3DDerInputProgram;
}());
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * WebGL kernel function for Conv2DBackpropFilter.
 *
 * Builds the convolution description from `x.shape` and `attrs.filterShape`
 * (dilations fixed to 1, not depthwise) and runs `Conv2DDerFilterProgram` on
 * `[x, dy]`.
 *
 * @param args Kernel arguments with `inputs` ({x, dy}), `backend` and `attrs`
 *     ({strides, pad, dataFormat, dimRoundingMode, filterShape}).
 * @returns The float32 tensor info produced by the program.
 */
function conv2DBackpropFilter(args) {
    var backend = args.backend;
    var attrs = args.attrs;
    var x = args.inputs.x;
    var dy = args.inputs.dy;
    var layout = tf.backend_util.convertConv2DDataFormat(attrs.dataFormat);
    var convInfo = tf.backend_util.computeConv2DInfo(x.shape, attrs.filterShape, attrs.strides, 1 /* dilations */, attrs.pad, attrs.dimRoundingMode, false /* depthwise */, layout);
    return backend.runWebGLProgram(new Conv2DDerFilterProgram(convInfo), [x, dy], 'float32');
}
|
|
/** Kernel config pairing `tf.Conv2DBackpropFilter` with the WebGL `conv2DBackpropFilter` implementation. */
var conv2DBackpropFilterConfig = { kernelName: tf.Conv2DBackpropFilter, backendName: 'webgl', kernelFunc: conv2DBackpropFilter };
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * WebGL kernel function for Conv2DBackpropInput.
 *
 * Builds the convolution description from `attrs.inputShape` and
 * `filter.shape` (dilations fixed to 1, not depthwise) and runs
 * `Conv2DDerInputProgram` on `[dy, filter]`.
 *
 * @param args Kernel arguments with `inputs` ({dy, filter}), `backend` and
 *     `attrs` ({inputShape, strides, pad, dataFormat, dimRoundingMode}).
 * @returns The float32 tensor info produced by the program.
 */
function conv2DBackpropInput(args) {
    var backend = args.backend;
    var attrs = args.attrs;
    var dy = args.inputs.dy;
    var filter = args.inputs.filter;
    var layout = tf.backend_util.convertConv2DDataFormat(attrs.dataFormat);
    var convInfo = tf.backend_util.computeConv2DInfo(attrs.inputShape, filter.shape, attrs.strides, 1 /* dilations */, attrs.pad, attrs.dimRoundingMode, false, layout);
    return backend.runWebGLProgram(new Conv2DDerInputProgram(convInfo), [dy, filter], 'float32');
}
|
|
/** Kernel config pairing `tf.Conv2DBackpropInput` with the WebGL `conv2DBackpropInput` implementation. */
var conv2DBackpropInputConfig = { kernelName: tf.Conv2DBackpropInput, backendName: 'webgl', kernelFunc: conv2DBackpropInput };
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * WebGL kernel function for Conv3D.
 *
 * Builds the 3D convolution description from the input and filter shapes
 * and runs `Conv3DProgram` on `[x, filter]`.
 *
 * @param args Kernel arguments with `inputs` ({x, filter}), `backend` and
 *     `attrs` ({strides, pad, dilations}).
 * @returns The float32 tensor info produced by the program.
 */
function conv3D(args) {
    var backend = args.backend;
    var attrs = args.attrs;
    var x = args.inputs.x;
    var filter = args.inputs.filter;
    var convInfo = tf.backend_util.computeConv3DInfo(x.shape, filter.shape, attrs.strides, attrs.dilations, attrs.pad);
    return backend.runWebGLProgram(new Conv3DProgram(convInfo), [x, filter], 'float32');
}
|
|
/** Kernel config pairing `tf.Conv3D` with the WebGL `conv3D` implementation. */
var conv3DConfig = { kernelName: tf.Conv3D, backendName: 'webgl', kernelFunc: conv3D };
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * WebGL kernel function for Conv3DBackpropFilterV2.
 *
 * Builds the 3D convolution description from `x.shape` and
 * `attrs.filterShape` (dilations fixed to 1) and runs
 * `Conv3DDerFilterProgram` on `[x, dy]`.
 *
 * @param args Kernel arguments with `inputs` ({x, dy}), `backend` and `attrs`
 *     ({strides, pad, filterShape}).
 * @returns The float32 tensor info produced by the program.
 */
function conv3DBackpropFilterV2(args) {
    var backend = args.backend;
    var attrs = args.attrs;
    var x = args.inputs.x;
    var dy = args.inputs.dy;
    var convInfo = tf.backend_util.computeConv3DInfo(x.shape, attrs.filterShape, attrs.strides, 1 /* dilations */, attrs.pad);
    return backend.runWebGLProgram(new Conv3DDerFilterProgram(convInfo), [x, dy], 'float32');
}
|
|
/** Kernel config pairing `tf.Conv3DBackpropFilterV2` with `conv3DBackpropFilterV2` under the 'webgl' backend name. */
var conv3DBackpropFilterV2Config = { kernelName: tf.Conv3DBackpropFilterV2, backendName: 'webgl', kernelFunc: conv3DBackpropFilterV2 };
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * WebGL kernel function for the input gradient of a 3D convolution.
 *
 * Computes the convolution geometry from the requested input shape and the
 * filter shape (dilations fixed to 1) and runs a `Conv3DDerInputProgram`
 * over `dy` and `filter`.
 *
 * @param {Object} args Kernel arguments.
 * @param {Object} args.inputs Holds `dy` and `filter`.
 * @param {Object} args.backend Backend exposing `runWebGLProgram`.
 * @param {Object} args.attrs Holds `pad`, `strides` and `inputShape`.
 * @returns The value returned by `backend.runWebGLProgram` (requested dtype 'float32').
 */
function conv3DBackpropInput(args) {
    var backend = args.backend;
    var upstreamGrad = args.inputs.dy;
    var kernel = args.inputs.filter;
    var options = args.attrs;
    var geometry = tf.backend_util.computeConv3DInfo(options.inputShape, kernel.shape, options.strides, 1 /* dilations */, options.pad);
    return backend.runWebGLProgram(new Conv3DDerInputProgram(geometry), [upstreamGrad, kernel], 'float32');
}
|
|
/** Kernel config pairing `tf.Conv3DBackpropInputV2` with `conv3DBackpropInput` under the 'webgl' backend name. */
var conv3DBackpropInputConfig = { kernelName: tf.Conv3DBackpropInputV2, backendName: 'webgl', kernelFunc: conv3DBackpropInput };
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// Shader snippet for element-wise cosine: the contents of
// CHECK_NAN_SNIPPET_UNARY (defined elsewhere in this file) followed by a
// `return cos(x);` statement.
var COS = CHECK_NAN_SNIPPET_UNARY + '\n return cos(x);\n';
// Kernel function built by `unaryKernelFunc` from the snippet above.
var cos = unaryKernelFunc({ opSnippet: COS });
/** Kernel config pairing `tf.Cos` with `cos` under the 'webgl' backend name. */
var cosConfig = { kernelName: tf.Cos, backendName: 'webgl', kernelFunc: cos };
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// Shader snippet for element-wise hyperbolic cosine. Despite its name, the
// GLSL local `e2x` holds exp(-x); the returned expression is
// (exp(-x) + 1 / exp(-x)) / 2.
var COSH = '\n float e2x = exp(-x);\n return (e2x + 1.0 / e2x) / 2.0;\n';
// Kernel function built by `unaryKernelFunc` from the snippet above.
var cosh = unaryKernelFunc({ opSnippet: COSH });
/** Kernel config pairing `tf.Cosh` with `cosh` under the 'webgl' backend name. */
var coshConfig = { kernelName: tf.Cosh, backendName: 'webgl', kernelFunc: cosh };
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2017 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Shader program that extracts boxes from a batch of images and resamples
 * each box to a fixed crop size.
 *
 * Shader inputs (see `variableNames`): `Image`, `Boxes` and `BoxInd`. Each row
 * of `Boxes` is read as (y1, x1, y2, x2) and scaled by (imageHeight - 1) /
 * (imageWidth - 1); `BoxInd[b]` is rounded to the index of the image that box
 * `b` is taken from.
 *
 * Fix over the previous revision: the image is now sampled at batch index
 * `bInd` (the value read from `BoxInd`) rather than at the box index `b`.
 * Previously `bInd` was computed and range-checked but never used, so boxes
 * were read from the wrong image whenever `BoxInd[b] != b`.
 *
 * NOTE(review): when `bInd` is outside [0, batch) the shader returns without
 * calling `setOutput`, exactly as before; confirm that is the intended
 * behaviour for invalid box indices.
 */
var CropAndResizeProgram = /** @class */ (function () {
    /**
     * @param {number[]} imageShape [batch, imageHeight, imageWidth, depth].
     * @param {number[]} boxShape Shape of the boxes input; only `boxShape[0]`
     *     (number of boxes) is used.
     * @param {number[]} cropSize [cropHeight, cropWidth] of every output crop.
     * @param {string} method 'bilinear' selects bilinear interpolation; any
     *     other value selects the nearest-neighbour branch.
     * @param {number} extrapolationValue Value written when the sampled
     *     coordinate falls outside the image.
     */
    function CropAndResizeProgram(imageShape, boxShape, cropSize, method, extrapolationValue) {
        this.variableNames = ['Image', 'Boxes', 'BoxInd'];
        this.outputShape = [];
        var batch = imageShape[0], imageHeight = imageShape[1], imageWidth = imageShape[2], depth = imageShape[3];
        var numBoxes = boxShape[0];
        var cropHeight = cropSize[0], cropWidth = cropSize[1];
        this.outputShape = [numBoxes, cropHeight, cropWidth, depth];
        var methodId = method === 'bilinear' ? 1 : 0;
        // Largest valid row/column index, rendered as a GLSL float literal.
        var inputHeightFloat = imageHeight - 1 + ".0";
        var inputWidthFloat = imageWidth - 1 + ".0";
        // With a single output row/column the crop samples the centre of the
        // box; otherwise it steps linearly from y1/x1 across the box.
        var heightRatio, heightScale, inY;
        if (cropHeight > 1) {
            heightRatio = "" + (imageHeight - 1) / (cropHeight - 1);
            heightScale = '(y2-y1) * height_ratio';
            inY = "y1*" + inputHeightFloat + " + float(y)*(height_scale)";
        }
        else {
            heightRatio = '0.0';
            heightScale = '0.0';
            inY = "0.5 * (y1+y2) * " + inputHeightFloat;
        }
        var widthRatio, widthScale, inX;
        if (cropWidth > 1) {
            widthRatio = "" + (imageWidth - 1) / (cropWidth - 1);
            widthScale = '(x2-x1) * width_ratio';
            inX = "x1*" + inputWidthFloat + " + float(x)*(width_scale)";
        }
        else {
            widthRatio = '0.0';
            widthScale = '0.0';
            inX = "0.5 * (x1+x2) * " + inputWidthFloat;
        }
        // Reference implementation
        // tslint:disable-next-line:max-line-length
        // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/crop_and_resize_op_gpu.cu.cc
        this.userCode = [
            "",
            "    const float height_ratio = float(" + heightRatio + ");",
            "    const float width_ratio = float(" + widthRatio + ");",
            "    void main() {",
            "      ivec4 coords = getOutputCoords();",
            "      int b = coords[0];",
            "      int y = coords[1];",
            "      int x = coords[2];",
            "      int d = coords[3];",
            "",
            "      // get box vals",
            "      float y1 = getBoxes(b,0);",
            "      float x1 = getBoxes(b,1);",
            "      float y2 = getBoxes(b,2);",
            "      float x2 = getBoxes(b,3);",
            "",
            "      // get image in batch index",
            "      int bInd = round(getBoxInd(b));",
            "      if(bInd < 0 || bInd >= " + batch + ") {",
            "        return;",
            "      }",
            "",
            "      float height_scale = " + heightScale + ";",
            "      float width_scale = " + widthScale + ";",
            "",
            "      float in_y = " + inY + ";",
            "      if( in_y < 0.0 || in_y > " + inputHeightFloat + " ) {",
            "        setOutput(float(" + extrapolationValue + "));",
            "        return;",
            "      }",
            "      float in_x = " + inX + ";",
            "      if( in_x < 0.0 || in_x > " + inputWidthFloat + " ) {",
            "        setOutput(float(" + extrapolationValue + "));",
            "        return;",
            "      }",
            "",
            "      vec2 sourceFracIndexCR = vec2(in_x,in_y);",
            "      if(" + methodId + " == 1) {",
            "        // Compute the four integer indices.",
            "        ivec2 sourceFloorCR = ivec2(sourceFracIndexCR);",
            "        ivec2 sourceCeilCR = ivec2(ceil(sourceFracIndexCR));",
            "",
            "        float topLeft = getImage(bInd, sourceFloorCR.y, sourceFloorCR.x, d);",
            "        float bottomLeft = getImage(bInd, sourceCeilCR.y, sourceFloorCR.x, d);",
            "        float topRight = getImage(bInd, sourceFloorCR.y, sourceCeilCR.x, d);",
            "        float bottomRight = getImage(bInd, sourceCeilCR.y, sourceCeilCR.x, d);",
            "",
            "        vec2 fracCR = sourceFracIndexCR - vec2(sourceFloorCR);",
            "",
            "        float top = topLeft + (topRight - topLeft) * fracCR.x;",
            "        float bottom = bottomLeft + (bottomRight - bottomLeft) * fracCR.x;",
            "        float newValue = top + (bottom - top) * fracCR.y;",
            "        setOutput(newValue);",
            "      } else {",
            "        // Compute the coordinators of nearest neighbor point.",
            "        ivec2 sourceNearestCR = ivec2(floor(",
            "          sourceFracIndexCR + vec2(0.5,0.5)));",
            "        float newValue = getImage(bInd, sourceNearestCR.y, sourceNearestCR.x, d);",
            "        setOutput(newValue);",
            "      }",
            "    }",
            "  "
        ].join("\n");
    }
    return CropAndResizeProgram;
}());
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * WebGL kernel function for crop-and-resize.
 *
 * Builds a `CropAndResizeProgram` from the image and boxes shapes plus the
 * crop attributes, then runs it over `image`, `boxes` and `boxInd`.
 *
 * @param {Object} args Kernel arguments.
 * @param {Object} args.inputs Holds `image`, `boxes` and `boxInd`.
 * @param {Object} args.backend Backend exposing `runWebGLProgram`.
 * @param {Object} args.attrs Holds `cropSize`, `method` and `extrapolationValue`.
 * @returns The value returned by `backend.runWebGLProgram` (requested dtype 'float32').
 */
var cropAndResize = function (args) {
    var backend = args.backend;
    var sources = args.inputs;
    var options = args.attrs;
    var shader = new CropAndResizeProgram(sources.image.shape, sources.boxes.shape, options.cropSize, options.method, options.extrapolationValue);
    return backend.runWebGLProgram(shader, [sources.image, sources.boxes, sources.boxInd], 'float32');
};
|
|
/** Kernel config pairing `tf.CropAndResize` with `cropAndResize` under the 'webgl' backend name. */
var cropAndResizeConfig = { kernelName: tf.CropAndResize, backendName: 'webgl', kernelFunc: cropAndResize };
|
|
|
|
/**
 * Shader program for one pass of a cumulative sum along the last axis.
 *
 * In the non-exclusive form each output element adds the element `pow2`
 * positions away (towards the start, or towards the end when `reverse` is
 * set), where `pow2 = 2^index` and `index` is a float uniform supplied through
 * `getCustomSetupFunc`. In the exclusive form the pass starts from 0.0 and
 * copies the neighbouring element one position away, i.e. it shifts the data
 * by one in the direction given by `reverse`.
 */
var CumSumProgram = /** @class */ (function () {
    /**
     * @param {number[]} shape Shape of the input, reused as the output shape.
     * @param {boolean} exclusive Selects the shift-by-one form described above.
     * @param {boolean} reverse Accumulate from the end of the axis instead of the start.
     */
    function CumSumProgram(shape, exclusive, reverse) {
        this.variableNames = ['x'];
        this.outputShape = shape;
        var rank = shape.length;
        var startValue;
        var guard;
        var sourceIndex;
        if (exclusive) {
            // When exclusive is set, the cumsum op becomes roll op that copies
            // the value from the previous index based on the direction
            // specified by the reverse flag.
            startValue = '0.0';
            var lastPosition = shape[rank - 1] - 1;
            guard = reverse ? "end != " + lastPosition : 'end != 0';
            sourceIndex = reverse ? 'end + 1' : 'end - 1';
        }
        else {
            startValue = "getX(" + getCoords$1(rank, 'coords') + ")";
            guard = reverse ? "end + pow2 < " + shape[rank - 1] : 'end >= pow2';
            sourceIndex = reverse ? 'end + pow2' : 'end - pow2';
        }
        this.userCode = "\n uniform float index;\n void main() {\n " + getCoordsDataType(rank) + " coords = getOutputCoords();\n int end = " + getFinalCoord(rank, 'coords') + ";\n float val = " + startValue + ";\n int pow2 = int(pow(2.0, index));\n if (" + guard + ") {\n int idx = " + sourceIndex + ";\n " + getFinalCoord(rank, 'coords') + " = idx;\n val += getX(" + getCoords$1(rank, 'coords') + ");\n }\n setOutput(val);\n }\n ";
    }
    /**
     * Returns a setup callback that uploads `index` to the shader's `index`
     * uniform. The uniform location is looked up on first use and cached on
     * this program instance.
     *
     * @param {number} index Value for the `index` uniform.
     * @returns {function(Object, Object): void} Callback taking (gpgpu, webGLProgram).
     */
    CumSumProgram.prototype.getCustomSetupFunc = function (index) {
        var self = this;
        return function (gpgpu, webGLProgram) {
            if (self.index == null) {
                self.index = gpgpu.getUniformLocation(webGLProgram, 'index');
            }
            gpgpu.gl.uniform1f(self.index, index);
        };
    };
    return CumSumProgram;
}());
|
|
/**
 * Builds the comma-separated list of coordinate components of `name` for a
 * tensor of the given rank: `name` itself for rank 1, otherwise `name.x`,
 * `name.y`, ... up to `name.w` for rank 4.
 *
 * @param {number} rank Tensor rank; only 1 through 4 are supported.
 * @param {string} name Name of the coordinates variable.
 * @returns {string} The component list.
 * @throws {Error} If `rank` is not 1, 2, 3 or 4.
 */
function getCoords$1(rank, name) {
    if (rank === 1) {
        return "" + name;
    }
    if (rank !== 2 && rank !== 3 && rank !== 4) {
        throw Error("Cumulative sum for rank " + rank + " is not yet supported");
    }
    var components = ['x', 'y', 'z', 'w'];
    var parts = [];
    for (var i = 0; i < rank; i++) {
        parts.push(name + "." + components[i]);
    }
    return parts.join(', ');
}
|
|
/**
 * Returns the expression for the last coordinate component of `name` for a
 * tensor of the given rank: `name` for rank 1, `name.y` for rank 2, `name.z`
 * for rank 3 and `name.w` for rank 4.
 *
 * @param {number} rank Tensor rank; only 1 through 4 are supported.
 * @param {string} name Name of the coordinates variable.
 * @returns {string} The last-component expression.
 * @throws {Error} If `rank` is not 1, 2, 3 or 4.
 */
function getFinalCoord(rank, name) {
    switch (rank) {
        case 1:
            return "" + name;
        case 2:
            return name + ".y";
        case 3:
            return name + ".z";
        case 4:
            return name + ".w";
        default:
            throw Error("Cumulative sum for rank " + rank + " is not yet supported");
    }
}
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * WebGL kernel function for cumulative sum along `axis`.
 *
 * Steps:
 *  1. If `getAxesPermutation` returns a permutation, transpose `x` with it.
 *  2. Copy the (possibly transposed) input with `identity`, then run
 *     ceil(log2(size)) `CumSumProgram` passes, passing the pass number to
 *     `getCustomSetupFunc`. Each pass consumes the previous result, which is
 *     then disposed.
 *  3. If `exclusive` is set, run one extra exclusive `CumSumProgram` pass.
 *  4. If the input was transposed, transpose the result back and dispose the
 *     intermediates.
 *
 * @param {Object} args Kernel arguments.
 * @param {Object} args.inputs Holds `x`.
 * @param {Object} args.backend Backend exposing `runWebGLProgram` and
 *     `disposeIntermediateTensorInfo`.
 * @param {Object} args.attrs Holds `axis`, `exclusive` and `reverse`.
 * @returns The final tensor info produced by the backend.
 * @throws {Error} If the inner-most axis reported by `getInnerMostAxes` is not
 *     the last axis of `x`.
 */
function cumsum(args) {
    var backend = args.backend;
    var x = args.inputs.x;
    var axis = args.attrs.axis;
    var exclusive = args.attrs.exclusive;
    var reverse = args.attrs.reverse;
    var rank = x.shape.length;
    var perm = tf.backend_util.getAxesPermutation([axis], rank);
    var wasTransposed = perm != null;
    var source = wasTransposed ?
        transpose({ inputs: { x: x }, backend: backend, attrs: { perm: perm } }) :
        x;
    var scanAxis = tf.backend_util.getInnerMostAxes(1, rank)[0];
    if (scanAxis !== rank - 1) {
        throw new Error("WebGL cumsum shader expects an inner-most axis=" + (x.shape.length - 1) + " but got axis=" + axis);
    }
    var size = source.shape[scanAxis];
    var running = identity({ inputs: { x: source }, backend: backend });
    var stale;
    // Use cumsum parallel algorithm, ref:
    // https://developer.nvidia.com/gpugems/gpugems3/part-vi-gpu-computing/chapter-39-parallel-prefix-sum-scan-cuda
    var passCount = Math.ceil(Math.log2(size));
    for (var pass = 0; pass < passCount; pass++) {
        var scanProgram = new CumSumProgram(source.shape, false, reverse);
        var uploadPass = scanProgram.getCustomSetupFunc(pass);
        stale = running;
        running = backend.runWebGLProgram(scanProgram, [stale], stale.dtype, uploadPass);
        backend.disposeIntermediateTensorInfo(stale);
    }
    // For exclusive cumsum, shift the end result in the direction of sum
    // and add 0 to the front index.
    if (exclusive) {
        var shiftProgram = new CumSumProgram(source.shape, exclusive, reverse);
        stale = running;
        running = backend.runWebGLProgram(shiftProgram, [stale], stale.dtype);
        backend.disposeIntermediateTensorInfo(stale);
    }
    if (!wasTransposed) {
        return running;
    }
    var undoPerm = tf.backend_util.getUndoAxesPermutation(perm);
    var restored = transpose({ inputs: { x: running }, backend: backend, attrs: { perm: undoPerm } });
    backend.disposeIntermediateTensorInfo(running);
    backend.disposeIntermediateTensorInfo(source);
    return restored;
}
|
|
/** Kernel config pairing `tf.Cumsum` with `cumsum` under the 'webgl' backend name. */
var cumsumConfig = { kernelName: tf.Cumsum, backendName: 'webgl', kernelFunc: cumsum };
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Kernel function for dense bincount, computed on the CPU.
 *
 * The input data is read back synchronously from the backend and handed to
 * the CPU implementations:
 *  - rank-1 `x`: `bincountImplCPU`, producing a tensor of shape `[size]`;
 *  - rank-2 `x`: `bincountReduceImplCPU`, whose buffer supplies the output
 *    shape and values.
 * The output dtype is always `weights.dtype`.
 *
 * NOTE(review): `binaryOutput` is only forwarded on the rank-2 path; the
 * rank-1 call does not pass it. Confirm against `bincountImplCPU`'s signature
 * whether that is intentional.
 *
 * @param {Object} args Kernel arguments.
 * @param {Object} args.inputs Holds `x` and `weights`.
 * @param {Object} args.backend Backend exposing `readSync`, `bufferSync` and
 *     `makeTensorInfo`.
 * @param {Object} args.attrs Holds `size` and `binaryOutput`.
 * @returns The tensor info created by `backend.makeTensorInfo`.
 * @throws {Error} If `x` is neither rank 1 nor rank 2.
 */
function denseBincount(args) {
    var inputs = args.inputs, backend = args.backend, attrs = args.attrs;
    var x = inputs.x, weights = inputs.weights;
    var size = attrs.size, binaryOutput = attrs.binaryOutput;
    var rank = x.shape.length;
    if (rank === 1) {
        var xVals = backend.readSync(x.dataId);
        var weightsVals = backend.readSync(weights.dataId);
        var outVals = bincountImplCPU(xVals, weightsVals, weights.dtype, weights.shape, size);
        return backend.makeTensorInfo([size], weights.dtype, outVals);
    }
    if (rank === 2) {
        var xBuf = backend.bufferSync(x);
        var weightsBuf = backend.bufferSync(weights);
        var outBuf = bincountReduceImplCPU(xBuf, weightsBuf, size, binaryOutput);
        return backend.makeTensorInfo(outBuf.shape, weights.dtype, outBuf.values);
    }
    // The previous message ran the word "rank" straight into the number
    // ("got rank3."); a separating space is now emitted.
    throw new Error("Error in denseBincount: input must be at most rank 2, but got rank " +
        (rank + "."));
}
|
|
/** Kernel config pairing `tf.DenseBincount` with `denseBincount` under the 'webgl' backend name. */
var denseBincountConfig = { kernelName: tf.DenseBincount, backendName: 'webgl', kernelFunc: denseBincount };
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Shader program that rearranges depth values into spatial blocks.
 *
 * For each output coordinate (b, h, w, d) the shader reads the input at
 * row h / blockSize and column w / blockSize; the depth offset is
 * (h mod blockSize) * blockSize + (w mod blockSize), multiplied by the
 * output depth and added to d. Which coordinate slots hold height, width and
 * depth depends on `dataFormat`: 'NHWC' uses slots 1/2/3, any other value
 * uses 2/3/1.
 */
var DepthToSpaceProgram = /** @class */ (function () {
    /**
     * @param {number[]} outputShape Shape of the output tensor (4 entries).
     * @param {number} blockSize Side length of the spatial block.
     * @param {string} dataFormat 'NHWC', or anything else for the
     *     channels-first layout.
     */
    function DepthToSpaceProgram(outputShape, blockSize, dataFormat) {
        this.variableNames = ['x'];
        this.outputShape = outputShape;
        this.blockSize = blockSize;
        this.dataFormat = dataFormat;
        var heightCoord = this.getHeightCoordString();
        var widthCoord = this.getWidthCoordString();
        var depthCoord = this.getDepthCoordString();
        var outputDepth = this.getOutputDepthSize();
        var sampleExpr = this.getInputSamplingString();
        this.userCode = "\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int h = " + heightCoord + ";\n int w = " + widthCoord + ";\n int d = " + depthCoord + ";\n\n int in_h = h / " + blockSize + ";\n int offset_h = imod(h, " + blockSize + ");\n int in_w = w / " + blockSize + ";\n int offset_w = imod(w, " + blockSize + ");\n int offset_d = (offset_h * " + blockSize + " + offset_w) *\n " + outputDepth + ";\n int in_d = d + offset_d;\n\n float result = " + sampleExpr + ";\n setOutput(result);\n }\n ";
    }
    /** @returns {string} Output-coordinate expression holding the height index. */
    DepthToSpaceProgram.prototype.getHeightCoordString = function () {
        return this.dataFormat === 'NHWC' ? "coords[1]" : "coords[2]";
    };
    /** @returns {string} Output-coordinate expression holding the width index. */
    DepthToSpaceProgram.prototype.getWidthCoordString = function () {
        return this.dataFormat === 'NHWC' ? "coords[2]" : "coords[3]";
    };
    /** @returns {string} Output-coordinate expression holding the depth index. */
    DepthToSpaceProgram.prototype.getDepthCoordString = function () {
        return this.dataFormat === 'NHWC' ? "coords[3]" : "coords[1]";
    };
    /** @returns {number} The output shape entry that holds the depth size. */
    DepthToSpaceProgram.prototype.getOutputDepthSize = function () {
        return this.dataFormat === 'NHWC' ? this.outputShape[3] : this.outputShape[1];
    };
    /** @returns {string} The `getX(...)` call with arguments ordered for the data format. */
    DepthToSpaceProgram.prototype.getInputSamplingString = function () {
        return this.dataFormat === 'NHWC' ?
            "getX(b, in_h, in_w, in_d)" :
            "getX(b, in_d, in_h, in_w)";
    };
    return DepthToSpaceProgram;
}());
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * WebGL kernel function for depth-to-space.
 *
 * Asserts `blockSize > 1`, derives the output shape (height and width
 * multiplied by `blockSize`, depth divided by `blockSize * blockSize`) in the
 * layout given by `dataFormat`, and runs a `DepthToSpaceProgram` over `x`.
 *
 * @param {Object} args Kernel arguments.
 * @param {Object} args.inputs Holds `x` (4 shape entries are read).
 * @param {Object} args.backend Backend exposing `runWebGLProgram`.
 * @param {Object} args.attrs Holds `blockSize` and `dataFormat` ('NHWC' or
 *     anything else for the channels-first layout).
 * @returns The value returned by `backend.runWebGLProgram` (dtype of `x`).
 */
function depthToSpace(args) {
    var backend = args.backend;
    var x = args.inputs.x;
    var blockSize = args.attrs.blockSize;
    var dataFormat = args.attrs.dataFormat;
    tf.util.assert(blockSize > 1, function () { return "blockSize should be > 1 for depthToSpace, but was: " + blockSize; });
    var channelsLast = dataFormat === 'NHWC';
    var dims = x.shape;
    var batchSize = dims[0];
    // Positions of height / width / depth inside the input shape.
    var outputHeight = dims[channelsLast ? 1 : 2] * blockSize;
    var outputWidth = dims[channelsLast ? 2 : 3] * blockSize;
    var outputDepth = dims[channelsLast ? 3 : 1] / (blockSize * blockSize);
    var outputShape;
    if (channelsLast) {
        outputShape = [batchSize, outputHeight, outputWidth, outputDepth];
    }
    else {
        outputShape = [batchSize, outputDepth, outputHeight, outputWidth];
    }
    return backend.runWebGLProgram(new DepthToSpaceProgram(outputShape, blockSize, dataFormat), [x], x.dtype);
}
|
|
/** Kernel config pairing `tf.DepthToSpace` with `depthToSpace` under the 'webgl' backend name. */
var depthToSpaceConfig = { kernelName: tf.DepthToSpace, backendName: 'webgl', kernelFunc: depthToSpace };
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2017 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Unpacked shader program for a 2D depthwise convolution.
 *
 * For each output element the shader walks the filter window, skips taps that
 * fall outside the input rows/columns, and accumulates x * w. An optional
 * bias is then added and an optional activation applied. Shader inputs are
 * `x` and `W`, followed (in this order) by `bias`, `preluActivationWeights`
 * and `leakyreluAlpha` when the corresponding flag is set.
 */
var DepthwiseConv2DProgram = /** @class */ (function () {
    /**
     * @param {Object} convInfo Convolution geometry; the fields read here are
     *     outShape, inHeight, inWidth, padInfo.top/left, strideHeight/Width,
     *     dilationHeight/Width, filterHeight/Width, outChannels, inChannels.
     * @param {boolean=} addBias Add `getBiasAtOutCoords()` to the result (default false).
     * @param {?string=} activation GLSL body of the activation function, or a
     *     falsy value for no activation (default null).
     * @param {boolean=} hasPreluActivation The activation takes a second
     *     operand read from `preluActivationWeights` (default false).
     * @param {boolean=} hasLeakyReluAlpha The activation takes a second
     *     operand read from `leakyreluAlpha` (default false).
     */
    function DepthwiseConv2DProgram(convInfo, addBias, activation, hasPreluActivation, hasLeakyReluAlpha) {
        var withBias = addBias === void 0 ? false : addBias;
        var activationBody = activation === void 0 ? null : activation;
        var withPrelu = hasPreluActivation === void 0 ? false : hasPreluActivation;
        var withLeakyRelu = hasLeakyReluAlpha === void 0 ? false : hasLeakyReluAlpha;
        this.variableNames = ['x', 'W'];
        this.outputShape = convInfo.outShape;
        var inRows = convInfo.inHeight;
        var inCols = convInfo.inWidth;
        var padding = convInfo.padInfo;
        // Number of output channels produced per input channel.
        var channelMul = convInfo.outChannels / convInfo.inChannels;
        var activationDecl = '';
        var activationCall = '';
        if (activationBody) {
            // Prelu takes precedence over leaky-relu when both flags are set.
            var secondOperandGetter = null;
            if (withPrelu) {
                secondOperandGetter = 'getPreluActivationWeightsAtOutCoords';
            }
            else if (withLeakyRelu) {
                secondOperandGetter = 'getLeakyreluAlphaAtOutCoords';
            }
            if (secondOperandGetter !== null) {
                activationDecl = "float activation(float a) {\n float b = " + secondOperandGetter + "();\n " + activationBody + "\n }";
            }
            else {
                activationDecl = "\n float activation(float x) {\n " + activationBody + "\n }\n ";
            }
            activationCall = "result = activation(result);";
        }
        var biasStatement = '';
        if (withBias) {
            biasStatement = 'result += getBiasAtOutCoords();';
            this.variableNames.push('bias');
        }
        if (withPrelu) {
            this.variableNames.push('preluActivationWeights');
        }
        if (withLeakyRelu) {
            this.variableNames.push('leakyreluAlpha');
        }
        this.userCode = "\n " + activationDecl + "\n\n const ivec2 strides = ivec2(" + convInfo.strideHeight + ", " + convInfo.strideWidth + ");\n const ivec2 pads = ivec2(" + padding.top + ", " + padding.left + ");\n\n void main() {\n ivec4 coords = getOutputCoords();\n int batch = coords.x;\n ivec2 xRCCorner = coords.yz * strides - pads;\n int d2 = coords.w;\n int d1 = d2 / " + channelMul + ";\n int q = d2 - d1 * " + channelMul + ";\n\n int xRCorner = xRCCorner.x;\n int xCCorner = xRCCorner.y;\n\n // Convolve x(?, ?, d1) with w(:, :, d1, q) to get y(yR, yC, d2).\n // ? = to be determined. : = across all values in that axis.\n float dotProd = 0.0;\n // TO DO(dsmilkov): Flatten the two for loops and vec4 the operations.\n for (int wR = 0; wR < " + convInfo.filterHeight + "; wR++) {\n int xR = xRCorner + wR * " + convInfo.dilationHeight + ";\n\n if (xR < 0 || xR >= " + inRows + ") {\n continue;\n }\n\n for (int wC = 0; wC < " + convInfo.filterWidth + "; wC++) {\n int xC = xCCorner + wC * " + convInfo.dilationWidth + ";\n\n if (xC < 0 || xC >= " + inCols + ") {\n continue;\n }\n\n float xVal = getX(batch, xR, xC, d1);\n float wVal = getW(wR, wC, d1, q);\n dotProd += xVal * wVal;\n }\n }\n\n float result = dotProd;\n " + biasStatement + "\n " + activationCall + "\n setOutput(result);\n }\n ";
    }
    return DepthwiseConv2DProgram;
}());
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
var DepthwiseConvPacked2DProgram = /** @class */ (function () {
|
|
function DepthwiseConvPacked2DProgram(convInfo, addBias, activation, hasPreluActivation, hasLeakyReluAlpha) {
|
|
if (addBias === void 0) { addBias = false; }
|
|
if (activation === void 0) { activation = null; }
|
|
if (hasPreluActivation === void 0) { hasPreluActivation = false; }
|
|
if (hasLeakyReluAlpha === void 0) { hasLeakyReluAlpha = false; }
|
|
this.variableNames = ['x', 'W'];
|
|
this.packedInputs = true;
|
|
this.packedOutput = true;
|
|
this.outputShape = convInfo.outShape;
|
|
var channelMul = convInfo.outChannels / convInfo.inChannels;
|
|
var xNumRows = convInfo.inHeight;
|
|
var xNumCols = convInfo.inWidth;
|
|
var padTop = convInfo.padInfo.top;
|
|
var padLeft = convInfo.padInfo.left;
|
|
var strideHeight = convInfo.strideHeight;
|
|
var strideWidth = convInfo.strideWidth;
|
|
var dilationHeight = convInfo.dilationHeight;
|
|
var dilationWidth = convInfo.dilationWidth;
|
|
var filterHeight = convInfo.filterHeight;
|
|
var filterWidth = convInfo.filterWidth;
|
|
var texelsAcross = filterWidth;
|
|
var mainLoop = "\n int xR; int xC; int xCOffset;\n vec4 wTexel; vec4 previous; vec4 final;";
|
|
for (var c = 0; c < filterWidth; c++) {
|
|
mainLoop += "\n vec4 xTexelC" + c * 2 + ";\n int xTexelC" + c * 2 + "Ready;\n vec4 xC" + c + ";";
|
|
}
|
|
/**
|
|
* This vectorized implementation works by gathering the values needed for
|
|
* each output channel's dot product into vec4's and then multiplying them
|
|
* all together (this happens in the final double for-loop below). Most of
|
|
* the main loop consists of constructing these vec4's with the minimum
|
|
* number of texture2D calls, which means making use of all four returned
|
|
* values from a texture2D call at once.
|
|
*/
|
|
for (var r = 0; r < filterHeight; r++) {
|
|
for (var c = 0; c < filterWidth; c++) {
|
|
mainLoop += "\n xTexelC" + c * 2 + " = vec4(0.0);\n xTexelC" + c * 2 + "Ready = 0;\n xC" + c + " = vec4(0.0);";
|
|
}
|
|
mainLoop += "\n xR = xRCorner + " + r * dilationHeight + ";\n if (xR >=0 && xR < " + xNumRows + ") {\n ";
|
|
for (var texelC = 0; texelC < (texelsAcross + 1) / 2; texelC++) {
|
|
var colIndex = texelC * 2;
|
|
var c = colIndex * dilationWidth;
|
|
mainLoop += "\n xC = xCCorner + " + c + ";\n ";
|
|
if (strideWidth === 1) {
|
|
if (colIndex < filterWidth) {
|
|
// If padding is odd, the outer texels have to be composed.
|
|
if (padLeft % 2 === 1) {
|
|
// TODO: Ensure vec4 previous does not result in redundant sample,
|
|
// and avoid setting xTexelRC's that exceed the boundary in the
|
|
// first place rather than resetting them to vec4(0)).
|
|
// To compute xCOffset:
|
|
// - If padding is odd, we must add 1 to ensure we ask for an
|
|
// even-numbered row.
|
|
// - We subtract 2 to access the previous texel.
|
|
mainLoop += "\n xCOffset = xC + 1;\n if (xCOffset >= 0 && xCOffset < " + xNumCols + " && xTexelC" + c + "Ready == 0) {\n xTexelC" + c + " = getX(batch, xR, xCOffset, d1);\n\n // Need to manually clear unused channels in case\n // we're reading from recycled texture.\n if (xCOffset + 1 >= " + xNumCols + ") {\n xTexelC" + c + ".zw = vec2(0.0);\n }\n xTexelC" + c + "Ready = 1;\n }\n ";
|
|
// This texel has been read in previous iteration if the dilation
|
|
// is 1.
|
|
if (dilationWidth === 1 && c > 0) {
|
|
mainLoop += "\n xC" + colIndex + " = vec4(xTexelC" + (c - 2) + ".zw, xTexelC" + c + ".xy);\n ";
|
|
}
|
|
else {
|
|
mainLoop += "\n xCOffset = xC + 1 - 2;\n\n if (xCOffset >= 0 && xCOffset < " + xNumCols + ") {\n previous = getX(batch, xR, xCOffset, d1);\n\n // Need to manually clear unused channels in case\n // we're reading from recycled texture.\n if (xCOffset + 1 >= " + xNumCols + ") {\n previous.zw = vec2(0.0);\n }\n\n xC" + colIndex + " = vec4(previous.zw, xTexelC" + c + ".xy);\n } else {\n xC" + colIndex + " = vec4(0.0, 0.0, xTexelC" + c + ".xy);\n }\n ";
|
|
}
|
|
}
|
|
else {
|
|
// Padding is even, so xRC corresponds to a single texel.
|
|
mainLoop += "\n if (xC >= 0 && xC < " + xNumCols + " && xTexelC" + c + "Ready == 0) {\n xTexelC" + c + " = getX(batch, xR, xC, d1);\n if (xC + 1 >= " + xNumCols + ") {\n xTexelC" + c + ".zw = vec2(0.0);\n }\n xTexelC" + c + "Ready = 1;\n }\n\n xC" + colIndex + " = xTexelC" + c + ";\n ";
|
|
}
|
|
if (c + 1 < filterWidth) {
|
|
// If dilation is even, the second entry should match the first
|
|
// (either both are composed or both are single samples). But if
|
|
// dilation is odd, then the second entry should be the opposite
|
|
// of the first (if the first is composed, the second is a single
|
|
// sample, and vice versa.)
|
|
var nextTexelOffset = padLeft % 2 === 0 ?
|
|
tf.util.nearestLargerEven(dilationWidth) :
|
|
dilationWidth;
|
|
if ((dilationWidth % 2 === 0 && padLeft % 2 === 1) ||
|
|
(dilationWidth % 2 !== 0 && padLeft % 2 !== 1)) {
|
|
mainLoop += "\n xCOffset = xC + " + padLeft % 2 + " + " + nextTexelOffset + ";\n\n if (xCOffset >= 0 && xCOffset < " + xNumCols + " && xTexelC" + (c + 2) + "Ready == 0) {\n xTexelC" + (c + 2) + " = getX(batch, xR, xCOffset, d1);\n\n // Need to manually clear unused channels in case\n // we're reading from recycled texture.\n if (xCOffset + 1 >= " + xNumCols + ") {\n xTexelC" + (c + 2) + ".zw = vec2(0.0);\n }\n xTexelC" + (c + 2) + "Ready = 1;\n }\n ";
|
|
// If dilation > 1 then the xRC's will not be able to share any
|
|
// values, so each xRC will require two unique calls to getX.
|
|
if (dilationWidth > 1) {
|
|
mainLoop += "\n xCOffset -= 2;\n if (xCOffset >= 0 && xCOffset < " + xNumCols + " && xTexelC" + c + "Ready == 0) {\n xTexelC" + c + " = getX(batch, xR, xCOffset, d1);\n xTexelC" + c + "Ready = 1;\n }\n ";
|
|
}
|
|
mainLoop += "\n xC" + (colIndex + 1) + " = vec4(xTexelC" + c + ".zw, xTexelC" + (c + 2) + ".xy);\n ";
|
|
}
|
|
else {
|
|
// If dilation is 1 and padding is odd, we have already read the
|
|
// texel when constructing the previous x value. Here we can
|
|
// simply skip the texture read.
|
|
if (nextTexelOffset === 1) {
|
|
mainLoop += "\n xC" + (colIndex + 1) + " = xTexelC" + c + ";\n ";
|
|
}
|
|
else {
|
|
mainLoop += "\n xCOffset = xC + " + nextTexelOffset + ";\n\n if (xCOffset >= 0 && xCOffset < " + xNumCols + " && xTexelC" + (c + 2) + "Ready == 0) {\n xTexelC" + (c + 2) + " = getX(batch, xR, xCOffset, d1);\n if (xCOffset + 1 >= " + xNumCols + ") {\n xTexelC" + (c + 2) + ".zw = vec2(0.0);\n }\n xTexelC" + (c + 2) + "Ready = 1;\n }\n\n xC" + (colIndex + 1) + " = xTexelC" + (c + 2) + ";\n ";
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
else { // stride === 2
|
|
if (c < filterWidth) {
|
|
// Depending on whether padLeft is even or odd, we want either the
|
|
// xy or zw channels from X texels for xC${colIndex}. If padLeft is
|
|
// even, xC${colIndex +1} is simply the zw channels of texels we've
|
|
// already sampled. But if padLeft is odd, xC{$c + 1}.zw will
|
|
// need to come from the xy channels of a new texel, hence the `
|
|
// vec4
|
|
// final` initialized below.
|
|
if (padLeft % 2 === 1) {
|
|
mainLoop += "\n xCOffset = xC + 1 - " + strideWidth + ";\n if(xCOffset >= 0 && xCOffset < " + xNumCols + " && xTexelC" + c + "Ready == 0) {\n xTexelC" + c + " = getX(batch, xR, xCOffset, d1);\n // Need to manually clear unused channels in case\n // we're reading from recycled texture.\n if (xCOffset + 1 >= " + xNumCols + ") {\n xTexelC" + c + ".zw = vec2(0.0);\n }\n xTexelC" + c + "Ready = 1;\n }\n\n if(xC + 1 >= 0 && xC + 1 < " + xNumCols + " && xTexelC" + (c + 2) + "Ready == 0) {\n xTexelC" + (c + 2) + " = getX(batch, xR, xC + 1, d1);\n // Need to manually clear unused channels in case\n // we're reading from recycled texture.\n if (xC + 2 >= " + xNumCols + ") {\n xTexelC" + (c + 2) + ".zw = vec2(0.0);\n }\n xTexelC" + (c + 2) + "Ready = 1;\n }\n\n xC" + colIndex + " = vec4(xTexelC" + c + ".zw, xTexelC" + (c + 2) + ".zw);\n ";
|
|
if (c + 1 < filterWidth) {
|
|
mainLoop += "\n final = vec4(0.0);\n xCOffset = xC + 1 + " + strideWidth + ";\n if(xCOffset >= 0 && xCOffset < " + xNumCols + ") {\n final = getX(batch, xR, xCOffset, d1);\n }\n xC" + (colIndex + 1) + " = vec4(xTexelC" + (c + 2) + ".xy, final.xy);\n ";
|
|
}
|
|
}
|
|
else {
|
|
mainLoop += "\n if(xC >= 0 && xC < " + xNumCols + " && xTexelC" + c + "Ready == 0) {\n xTexelC" + c + " = getX(batch, xR, xC, d1);\n if (xC + 1 >= " + xNumCols + ") {\n xTexelC" + c + ".zw = vec2(0.0);\n }\n xTexelC" + c + "Ready = 1;\n }\n\n xCOffset = xC + " + strideWidth + ";\n if(xCOffset >= 0 && xCOffset < " + xNumCols + " && xTexelC" + (c + 2) + "Ready == 0) {\n xTexelC" + (c + 2) + " = getX(batch, xR, xCOffset, d1);\n if (xCOffset + 1 >= " + xNumCols + ") {\n xTexelC" + (c + 2) + ".zw = vec2(0.);\n }\n xTexelC" + (c + 2) + "Ready = 1;\n }\n\n xC" + colIndex + " = vec4(\n xTexelC" + c + ".xy, xTexelC" + (c + 2) + ".xy);\n ";
|
|
if (c + 1 < filterWidth) {
|
|
mainLoop += "\n xC" + (colIndex + 1) + " = vec4(xTexelC" + c + ".zw, xTexelC" + (c + 2) + ".zw);\n ";
|
|
}
|
|
}
|
|
}
|
|
}
|
|
// localize the dotProd accumulation within the loop, the theory is for
|
|
// GPU with limited cache, accumulate sum across large amount of
|
|
// variables will cause lots of cache misses. (i.e. 5x5 filter will have
|
|
// 50 variables)
|
|
if (colIndex < filterWidth) {
|
|
mainLoop += "\n wTexel = getW(" + r + ", " + c + ", d1, q);\n dotProd += xC" + colIndex + " * vec4(wTexel.xz, wTexel.xz);\n ";
|
|
if (c + 1 < filterWidth) {
|
|
mainLoop += "\n wTexel = getW(" + r + ", " + (c + 1) + ", d1, q);\n dotProd += xC" + (colIndex + 1) + " * vec4(wTexel.xz, wTexel.xz);\n ";
|
|
}
|
|
}
|
|
}
|
|
mainLoop += "\n }\n ";
|
|
}
|
|
var activationSnippet = '', applyActivationSnippet = '';
|
|
if (activation) {
|
|
if (hasPreluActivation) {
|
|
activationSnippet = "vec4 activation(vec4 a) {\n vec4 b = getPreluActivationWeightsAtOutCoords();\n " + activation + "\n }";
|
|
}
|
|
else if (hasLeakyReluAlpha) {
|
|
activationSnippet = "vec4 activation(vec4 a) {\n vec4 b = getLeakyreluAlphaAtOutCoords();\n " + activation + "\n }";
|
|
}
|
|
else {
|
|
activationSnippet = "vec4 activation(vec4 x) {\n " + activation + "\n }";
|
|
}
|
|
applyActivationSnippet = "result = activation(result);";
|
|
}
|
|
var addBiasSnippet = addBias ? 'result += getBiasAtOutCoords();' : '';
|
|
if (addBias) {
|
|
this.variableNames.push('bias');
|
|
}
|
|
if (hasPreluActivation) {
|
|
this.variableNames.push('preluActivationWeights');
|
|
}
|
|
if (hasLeakyReluAlpha) {
|
|
this.variableNames.push('leakyreluAlpha');
|
|
}
|
|
this.userCode = "\n " + activationSnippet + "\n\n const ivec2 strides = ivec2(" + strideHeight + ", " + strideWidth + ");\n const ivec2 pads = ivec2(" + padTop + ", " + padLeft + ");\n\n void main() {\n\n ivec4 coords = getOutputCoords();\n int batch = coords.x;\n ivec2 xRCCorner = coords.yz * strides - pads;\n int d2 = coords.w;\n int d1 = d2 / " + channelMul + ";\n int q = d2 - d1 * " + channelMul + ";\n int xRCorner = xRCCorner.x;\n int xCCorner = xRCCorner.y;\n\n //intialize dotProd with a small epsilon seems to reduce GPU accuracy loss.\n vec4 dotProd = vec4(0.000000000000001);\n\n " + mainLoop + "\n\n vec4 result = dotProd - vec4(0.000000000000001);\n " + addBiasSnippet + "\n " + applyActivationSnippet + "\n setOutput(result);\n }\n ";
|
|
}
|
|
return DepthwiseConvPacked2DProgram;
|
|
}());
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * WebGL kernel function for the depthwise 2D convolution.
 *
 * Reads `x` and `filter` from `args.inputs` and `strides`, `pad`,
 * `dilations` and `dimRoundingMode` from `args.attrs`. A null/undefined
 * `dilations` is treated as [1, 1]. The packed shader program is chosen when
 * the WEBGL_PACK_DEPTHWISECONV flag is set, the stride width is at most 2 and
 * the channel multiplier is exactly 1; otherwise the unpacked program runs.
 *
 * @param args Object carrying `inputs`, `backend` and `attrs`.
 * @returns The result of `backend.runWebGLProgram(...)` with dtype 'float32'.
 */
function depthwiseConv2dNative(args) {
    var inputs = args.inputs;
    var backend = args.backend;
    var attrs = args.attrs;
    var x = inputs.x;
    var filter = inputs.filter;
    var strides = attrs.strides;
    var pad = attrs.pad;
    var dilations = attrs.dilations;
    var dimRoundingMode = attrs.dimRoundingMode;
    // Fall back to unit dilation when the attribute was not supplied.
    var effectiveDilations = dilations == null ? [1, 1] : dilations;
    tf.util.assert(tf.backend_util.eitherStridesOrDilationsAreOne(strides, effectiveDilations), function () {
        return 'Error in depthwiseConv2d: Either strides or dilations must be ' +
            ("1. Got strides " + strides + " and dilations '" + effectiveDilations + "'");
    });
    var convInfo = tf.backend_util.computeConv2DInfo(x.shape, filter.shape, strides, effectiveDilations, pad, dimRoundingMode, true /* depthwise */);
    var canUsePackedProgram = tf.env().getBool('WEBGL_PACK_DEPTHWISECONV') &&
        convInfo.strideWidth <= 2 &&
        convInfo.outChannels / convInfo.inChannels === 1;
    var program = canUsePackedProgram ?
        new DepthwiseConvPacked2DProgram(convInfo) :
        new DepthwiseConv2DProgram(convInfo);
    return backend.runWebGLProgram(program, [x, filter], 'float32');
}
|
|
// Kernel config tying tf.DepthwiseConv2dNative to depthwiseConv2dNative on
// the 'webgl' backend.
var depthwiseConv2dNativeConfig = {
    kernelName: tf.DepthwiseConv2dNative,
    backendName: 'webgl',
    kernelFunc: depthwiseConv2dNative
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
var DepthwiseConv2DDerFilterProgram = /** @class */ (function () {
    /**
     * Shader program description for the filter gradient of a depthwise
     * convolution. The generated GLSL accumulates, for every filter
     * coordinate, the product of `x` and `dy` over all batches and output
     * positions whose corresponding input position lies inside the input.
     *
     * @param convInfo Convolution description; reads filterShape,
     *     strideHeight/Width, padInfo.top/left, in/outChannels, batchSize,
     *     outHeight/Width and inHeight/Width.
     */
    function DepthwiseConv2DDerFilterProgram(convInfo) {
        this.variableNames = ['x', 'dy'];
        this.outputShape = convInfo.filterShape;
        var strideH = convInfo.strideHeight;
        var strideW = convInfo.strideWidth;
        var topPad = convInfo.padInfo.top;
        var leftPad = convInfo.padInfo.left;
        // Number of output channels produced per input channel.
        var multiplier = convInfo.outChannels / convInfo.inChannels;
        this.userCode = "\n void main() {" +
            "\n ivec4 coords = getOutputCoords();" +
            "\n int wR = coords.x;" +
            "\n int wC = coords.y;" +
            "\n int d1 = coords.z;" +
            "\n int dm = coords.w;" +
            "\n int d2 = d1 * " + multiplier + " + dm;" +
            "\n\n float dotProd = 0.0;" +
            "\n\n // TO DO: Vec4 over the batch size" +
            "\n for (int b = 0; b < " + convInfo.batchSize + "; b++) {" +
            "\n for (int yR = 0; yR < " + convInfo.outHeight + "; yR++) {" +
            "\n int xR = wR + yR * " + strideH + " - " + topPad + ";" +
            "\n\n if (xR < 0 || xR >= " + convInfo.inHeight + ") {" +
            "\n continue;" +
            "\n }" +
            "\n\n for (int yC = 0; yC < " + convInfo.outWidth + "; yC++) {" +
            "\n int xC = wC + yC * " + strideW + " - " + leftPad + ";" +
            "\n\n if (xC < 0 || xC >= " + convInfo.inWidth + ") {" +
            "\n continue;" +
            "\n }" +
            "\n\n float dyValue = getDy(b, yR, yC, d2);" +
            "\n float xValue = getX(b, xR, xC, d1);" +
            "\n dotProd += (xValue * dyValue);" +
            "\n }" +
            "\n }" +
            "\n }" +
            "\n setOutput(dotProd);" +
            "\n }" +
            "\n ";
    }
    return DepthwiseConv2DDerFilterProgram;
}());
|
|
var DepthwiseConv2DDerInputProgram = /** @class */ (function () {
    /**
     * Shader program description for the input gradient of a depthwise
     * convolution. The generated GLSL walks the filter window, skips
     * positions that do not land on an integer, in-range `dy` coordinate
     * after dividing by the stride, and accumulates `dy * W` over the channel
     * multiplier using the filter indices `filterSize - 1 - index`.
     *
     * @param convInfo Convolution description; reads inShape,
     *     filterHeight/Width, strideHeight/Width, padInfo.top/left,
     *     in/outChannels and outHeight/Width.
     */
    function DepthwiseConv2DDerInputProgram(convInfo) {
        this.variableNames = ['dy', 'W'];
        this.outputShape = convInfo.inShape;
        var fHeight = convInfo.filterHeight;
        var fWidth = convInfo.filterWidth;
        var strideH = convInfo.strideHeight;
        var strideW = convInfo.strideWidth;
        // Padding used by the shader: filter size - 1 - forward padding.
        var gradPadTop = fHeight - 1 - convInfo.padInfo.top;
        var gradPadLeft = fWidth - 1 - convInfo.padInfo.left;
        var multiplier = convInfo.outChannels / convInfo.inChannels;
        this.userCode = "\n const ivec2 pads = ivec2(" + gradPadTop + ", " + gradPadLeft + ");" +
            "\n\n void main() {" +
            "\n ivec4 coords = getOutputCoords();" +
            "\n int batch = coords[0];" +
            "\n int d1 = coords[3];" +
            "\n ivec2 dyCorner = coords.yz - pads;" +
            "\n int dyRCorner = dyCorner.x;" +
            "\n int dyCCorner = dyCorner.y;" +
            "\n\n float dotProd = 0.0;" +
            "\n\n for (int wR = 0; wR < " + fHeight + "; wR++) {" +
            "\n float dyR = float(dyRCorner + wR) / " + strideH + ".0;" +
            "\n\n if (dyR < 0.0 || dyR >= " + convInfo.outHeight + ".0 || fract(dyR) > 0.0) {" +
            "\n continue;" +
            "\n }" +
            "\n int idyR = int(dyR);" +
            "\n\n int wRPerm = " + fHeight + " - 1 - wR;" +
            "\n\n for (int wC = 0; wC < " + fWidth + "; wC++) {" +
            "\n float dyC = float(dyCCorner + wC) / " + strideW + ".0;" +
            "\n\n if (dyC < 0.0 || dyC >= " + convInfo.outWidth + ".0 ||" +
            "\n fract(dyC) > 0.0) {" +
            "\n continue;" +
            "\n }" +
            "\n int idyC = int(dyC);" +
            "\n\n int wCPerm = " + fWidth + " - 1 - wC;" +
            "\n\n // TO DO: Vec4 over the channelMul" +
            "\n for (int dm = 0; dm < " + multiplier + "; dm++) {" +
            "\n int d2 = d1 * " + multiplier + " + dm;" +
            "\n float xValue = getDy(batch, idyR, idyC, d2);" +
            "\n float wValue = getW(wRPerm, wCPerm, d1, dm);" +
            "\n dotProd += xValue * wValue;" +
            "\n }" +
            "\n }" +
            "\n }" +
            "\n setOutput(dotProd);" +
            "\n }" +
            "\n ";
    }
    return DepthwiseConv2DDerInputProgram;
}());
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * WebGL kernel function computing the filter gradient of a depthwise 2D
 * convolution by running DepthwiseConv2DDerFilterProgram on `x` and `dy`.
 *
 * @param args Object carrying `inputs` ({x, dy}), `backend` and `attrs`
 *     ({strides, dilations, pad, dimRoundingMode, filterShape}).
 * @returns The result of `backend.runWebGLProgram(...)` with dtype 'float32'.
 */
function depthwiseConv2dNativeBackpropFilter(args) {
    var inputs = args.inputs;
    var backend = args.backend;
    var attrs = args.attrs;
    var x = inputs.x;
    var dy = inputs.dy;
    var strides = attrs.strides;
    var dilations = attrs.dilations;
    var pad = attrs.pad;
    var dimRoundingMode = attrs.dimRoundingMode;
    var filterShape = attrs.filterShape;
    var convInfo = tf.backend_util.computeConv2DInfo(x.shape, filterShape, strides, dilations, pad, dimRoundingMode, true /* depthwise */);
    return backend.runWebGLProgram(new DepthwiseConv2DDerFilterProgram(convInfo), [x, dy], 'float32');
}
|
|
// Kernel config tying tf.DepthwiseConv2dNativeBackpropFilter to
// depthwiseConv2dNativeBackpropFilter on the 'webgl' backend.
var depthwiseConv2dNativeBackpropFilterConfig = {
    kernelName: tf.DepthwiseConv2dNativeBackpropFilter,
    backendName: 'webgl',
    kernelFunc: depthwiseConv2dNativeBackpropFilter
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * WebGL kernel function computing the input gradient of a depthwise 2D
 * convolution by running DepthwiseConv2DDerInputProgram on `dy` and `filter`.
 *
 * @param args Object carrying `inputs` ({dy, filter}), `backend` and `attrs`
 *     ({strides, dilations, pad, dimRoundingMode, inputShape}).
 * @returns The result of `backend.runWebGLProgram(...)` with dtype 'float32'.
 */
function depthwiseConv2dNativeBackpropInput(args) {
    var inputs = args.inputs;
    var backend = args.backend;
    var attrs = args.attrs;
    var dy = inputs.dy;
    var filter = inputs.filter;
    var strides = attrs.strides;
    var dilations = attrs.dilations;
    var pad = attrs.pad;
    var dimRoundingMode = attrs.dimRoundingMode;
    var inputShape = attrs.inputShape;
    var convInfo = tf.backend_util.computeConv2DInfo(inputShape, filter.shape, strides, dilations, pad, dimRoundingMode, true /* depthwise */);
    return backend.runWebGLProgram(new DepthwiseConv2DDerInputProgram(convInfo), [dy, filter], 'float32');
}
|
|
// Kernel config tying tf.DepthwiseConv2dNativeBackpropInput to
// depthwiseConv2dNativeBackpropInput on the 'webgl' backend.
var depthwiseConv2dNativeBackpropInputConfig = {
    kernelName: tf.DepthwiseConv2dNativeBackpropInput,
    backendName: 'webgl',
    kernelFunc: depthwiseConv2dNativeBackpropInput
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
var DiagProgram = /** @class */ (function () {
    /**
     * Shader program description that writes a `size` x `size` output whose
     * diagonal entries come from the 1-D input `X` and whose off-diagonal
     * entries are 0.0.
     *
     * @param size Side length of the square output.
     */
    function DiagProgram(size) {
        this.variableNames = ['X'];
        this.outputShape = [size, size];
        this.userCode = "\n void main() {" +
            "\n ivec2 coords = getOutputCoords();" +
            "\n float val = coords[0] == coords[1] ? getX(coords[0]) : 0.0;" +
            "\n setOutput(val);" +
            "\n }" +
            "\n ";
    }
    return DiagProgram;
}());
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * WebGL kernel function for Diag: flattens `x`, runs DiagProgram on it and
 * reshapes the square result to `x.shape` concatenated with itself. The two
 * intermediate tensors (the flattened input and the raw program output) are
 * disposed before returning.
 *
 * @param args Object carrying `inputs` ({x}) and `backend`.
 * @returns The reshaped tensor info.
 */
function diag(args) {
    var inputs = args.inputs;
    var backend = args.backend;
    var x = inputs.x;
    var resultShape = x.shape.concat(x.shape);
    var elementCount = tf.util.sizeFromShape(x.shape);
    var flattened = reshape({
        inputs: { x: x },
        backend: backend,
        attrs: { shape: [elementCount] }
    });
    var squareResult = backend.runWebGLProgram(new DiagProgram(elementCount), [flattened], flattened.dtype);
    var reshaped = reshape({
        inputs: { x: squareResult },
        backend: backend,
        attrs: { shape: resultShape }
    });
    backend.disposeIntermediateTensorInfo(flattened);
    backend.disposeIntermediateTensorInfo(squareResult);
    return reshaped;
}
|
|
// Kernel config tying tf.Diag to diag on the 'webgl' backend.
var diagConfig = {
    kernelName: tf.Diag,
    backendName: 'webgl',
    kernelFunc: diag
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2017 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
var Dilation2DProgram = /** @class */ (function () {
    /**
     * Shader program description for 2D dilation. For every output element
     * the generated GLSL scans the filter window (stepping by the dilation
     * rates), ignores taps that fall outside the input, and keeps the maximum
     * of `x + W`, starting from -3.4e38.
     *
     * @param convInfo Dilation description; reads outShape, inHeight/Width,
     *     padInfo.top/left, strideHeight/Width, filterHeight/Width and
     *     dilationHeight/Width.
     */
    function Dilation2DProgram(convInfo) {
        this.variableNames = ['x', 'W'];
        this.outputShape = convInfo.outShape;
        var inH = convInfo.inHeight;
        var inW = convInfo.inWidth;
        var padding = convInfo.padInfo;
        var strideH = convInfo.strideHeight;
        var strideW = convInfo.strideWidth;
        var filterH = convInfo.filterHeight;
        var filterW = convInfo.filterWidth;
        var dilationH = convInfo.dilationHeight;
        var dilationW = convInfo.dilationWidth;
        var topPad = padding.top;
        var leftPad = padding.left;
        this.userCode = "\n const ivec2 strides = ivec2(" + strideH + ", " + strideW + ");" +
            "\n const ivec2 pads = ivec2(" + topPad + ", " + leftPad + ");" +
            "\n const float neg_infinity = -3.4e38;" +
            "\n\n void main() {" +
            "\n ivec4 coords = getOutputCoords();" +
            "\n int batch = coords.x;" +
            "\n int d1 = coords.w;" +
            "\n ivec2 outTopLeftCorner =" +
            "\n coords.yz * strides - pads;" +
            "\n int hBeg = outTopLeftCorner.x;" +
            "\n int wBeg = outTopLeftCorner.y;" +
            "\n\n float curVal = neg_infinity;" +
            "\n for (int h = 0; h < " + filterH + "; h++) {" +
            "\n int hIn = hBeg + h * " + dilationH + ";" +
            "\n\n if (hIn >= 0 && hIn < " + inH + ") {" +
            "\n for (int w = 0; w < " + filterW + "; w++) {" +
            "\n int wIn = wBeg + w * " + dilationW + ";" +
            "\n\n if (wIn >= 0 && wIn < " + inW + ") {" +
            "\n float xVal = getX(batch, hIn, wIn, d1);" +
            "\n float wVal = getW(h, w, d1);" +
            "\n\n float val = xVal + wVal;" +
            "\n if (val > curVal) {" +
            "\n curVal = val;" +
            "\n }" +
            "\n }" +
            "\n }" +
            "\n }" +
            "\n }" +
            "\n\n float result = curVal;" +
            "\n setOutput(result);" +
            "\n }" +
            "\n ";
    }
    return Dilation2DProgram;
}());
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * WebGL kernel function for Dilation2D. Runs Dilation2DProgram on `x` and
 * `filter` (data format 'NHWC'), reshapes the program output to
 * `convInfo.outShape` and disposes the pre-reshape tensor.
 *
 * @param args Object carrying `inputs` ({x, filter}), `backend` and `attrs`
 *     ({strides, pad, dilations}).
 * @returns The reshaped tensor info.
 */
function dilation2D(args) {
    var inputs = args.inputs;
    var backend = args.backend;
    var attrs = args.attrs;
    var x = inputs.x;
    var filter = inputs.filter;
    var strides = attrs.strides;
    var pad = attrs.pad;
    var dilations = attrs.dilations;
    var convInfo = tf.backend_util.computeDilation2DInfo(x.shape, filter.shape, strides, pad, 'NHWC' /* dataFormat */, dilations);
    var rawOutput = backend.runWebGLProgram(new Dilation2DProgram(convInfo), [x, filter], 'float32');
    var reshapedOutput = reshape({
        inputs: { x: rawOutput },
        backend: backend,
        attrs: { shape: convInfo.outShape }
    });
    backend.disposeIntermediateTensorInfo(rawOutput);
    return reshapedOutput;
}
|
|
// Kernel config tying tf.Dilation2D to dilation2D on the 'webgl' backend.
var dilation2DConfig = {
    kernelName: tf.Dilation2D,
    backendName: 'webgl',
    kernelFunc: dilation2D
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2021 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * WebGL kernel function for Einsum.
 *
 * Decodes the equation, then follows the compute path returned by
 * `tf.backend_util.getEinsumComputePath`: at each step every listed input
 * term is transposed (unless its permutation is the identity), reshaped to
 * insert size-1 dimensions, and multiplied into the running product; between
 * steps the running product is summed over `path[i]` when that entry is
 * non-negative. Every tensor created along the way, except the one that is
 * returned, is disposed at the end.
 *
 * @param args Object carrying `inputs` (the input tensors), `backend` and
 *     `attrs` ({equation}).
 * @returns The tensor info holding the final product.
 */
function einsum(args) {
    var inputs = args.inputs;
    var backend = args.backend;
    var attrs = args.attrs;
    var equation = attrs.equation;
    var tensors = inputs;
    var decoded = tf.backend_util.decodeEinsumEquation(equation, tensors.length);
    var allDims = decoded.allDims;
    var summedDims = decoded.summedDims;
    var idDims = decoded.idDims;
    tf.backend_util.checkEinsumDimSizes(allDims.length, idDims, tensors);
    var computePath = tf.backend_util.getEinsumComputePath(summedDims, idDims);
    var path = computePath.path;
    var steps = computePath.steps;
    var stepCount = steps.length;
    var product = null;
    var dimsLeft = allDims.length;
    var intermediates = [];
    for (var stepIdx = 0; stepIdx < stepCount; ++stepIdx) {
        var termIds = steps[stepIdx];
        for (var t = 0; t < termIds.length; t++) {
            var termId = termIds[t];
            var permInfo = tf.backend_util.getEinsumPermutation(dimsLeft, idDims[termId]);
            var perm = permInfo.permutationIndices;
            var dimsToExpand = permInfo.expandDims;
            var operand = void 0;
            if (tf.backend_util.isIdentityPermutation(perm)) {
                // No transpose needed; use the input tensor as-is.
                operand = tensors[termId];
            }
            else {
                operand = transpose({ inputs: { x: tensors[termId] }, backend: backend, attrs: { perm: perm } });
                intermediates.push(operand);
            }
            // Insert a size-1 dimension at each requested position.
            var expandedShape = operand.shape.slice();
            for (var e = 0; e < dimsToExpand.length; ++e) {
                expandedShape.splice(dimsToExpand[e], 0, 1);
            }
            if (!tf.util.arraysEqual(operand.shape, expandedShape)) {
                operand = reshape({ inputs: { x: operand }, backend: backend, attrs: { shape: expandedShape } });
                intermediates.push(operand);
            }
            if (product === null) {
                product = operand;
            }
            else {
                // tslint:disable-next-line: no-unnecessary-type-assertion
                product = multiply({ inputs: { a: operand, b: product }, backend: backend });
                intermediates.push(product);
            }
        }
        // The last step performs no reduction.
        if (stepIdx < stepCount - 1) {
            if (path[stepIdx] >= 0) {
                product = sum({
                    inputs: { x: product },
                    backend: backend,
                    attrs: {
                        // Shift the axis by the number of dimensions already removed.
                        axis: path[stepIdx] - (allDims.length - dimsLeft),
                        keepDims: false
                    }
                });
                intermediates.push(product);
            }
            dimsLeft--;
        }
    }
    // Clean up intermediate tensors, keeping the one being returned.
    for (var d = 0; d < intermediates.length; d++) {
        var candidate = intermediates[d];
        if (candidate !== product) {
            backend.disposeIntermediateTensorInfo(candidate);
        }
    }
    return product;
}
|
|
// Kernel config tying tf.Einsum to einsum on the 'webgl' backend.
var einsumConfig = {
    kernelName: tf.Einsum,
    backendName: 'webgl',
    kernelFunc: einsum
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// Scalar GLSL snippet for ELU: x for x >= 0, exp(x) - 1 otherwise.
var ELU$2 = "return (x >= 0.0) ? x : (exp(x) - 1.0);";
// Packed GLSL snippet applying the same rule to each of the four channels.
var ELU_PACKED = "\n vec4 result;" +
    "\n\n result.r = (x.r >= 0.0) ? x.r : (exp(x.r) - 1.0);" +
    "\n result.g = (x.g >= 0.0) ? x.g : (exp(x.g) - 1.0);" +
    "\n result.b = (x.b >= 0.0) ? x.b : (exp(x.b) - 1.0);" +
    "\n result.a = (x.a >= 0.0) ? x.a : (exp(x.a) - 1.0);" +
    "\n\n return result;" +
    "\n";
// Unary kernel built from the two snippets above.
var elu = unaryKernelFunc({
    opSnippet: ELU$2,
    packedOpSnippet: ELU_PACKED
});
// Kernel config tying tf.Elu to elu on the 'webgl' backend.
var eluConfig = {
    kernelName: tf.Elu,
    backendName: 'webgl',
    kernelFunc: elu
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// Scalar GLSL snippet for the ELU gradient, where `a` is the incoming
// gradient (dy) and `b` is the forward ELU output (y). ELU is the identity
// for non-negative outputs, so the gradient passes through unchanged when
// b >= 0; otherwise y = exp(x) - 1 and the local derivative is y + 1.
// The threshold must be 0.0 (not 1.0) so outputs in [0, 1) are not scaled.
var ELU_DER = "return (b >= 0.0) ? a : a * (b + 1.0);";
|
|
// Packed GLSL snippet for the ELU gradient: a mask of (b >= 0) selects `a`,
// and the complementary mask selects a * (b + 1).
var ELU_DER_PACKED = "\n vec4 bGTEZero = vec4(greaterThanEqual(b, vec4(0.)));" +
    "\n return (bGTEZero * a) + ((vec4(1.0) - bGTEZero) * (a * (b + vec4(1.0))));" +
    "\n";
/**
 * WebGL kernel function for EluGrad. Runs a binary-op program over `dy` and
 * `y`, using the packed program when WEBGL_PACK_BINARY_OPERATIONS is set.
 *
 * @param args Object carrying `inputs` ({dy, y}) and `backend`.
 * @returns The result of `backend.runWebGLProgram(...)` with dtype `dy.dtype`.
 */
var eluGrad = function (args) {
    var inputs = args.inputs;
    var backend = args.backend;
    var dy = inputs.dy;
    var y = inputs.y;
    var program;
    if (tf.env().getBool('WEBGL_PACK_BINARY_OPERATIONS')) {
        program = new BinaryOpPackedProgram(ELU_DER_PACKED, dy.shape, y.shape);
    }
    else {
        program = new BinaryOpProgram(ELU_DER, dy.shape, y.shape);
    }
    return backend.runWebGLProgram(program, [dy, y], dy.dtype);
};
// Kernel config tying tf.EluGrad to eluGrad on the 'webgl' backend.
var eluGradConfig = {
    kernelName: tf.EluGrad,
    backendName: 'webgl',
    kernelFunc: eluGrad
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// Packed GLSL snippet: component-wise equality converted to a vec4.
var PACKED_EQUAL = "\n return vec4(equal(a, b));" + "\n";
// Scalar GLSL snippet: equality converted to a float.
var EQUAL = "return float(a == b);";
// Binary kernel with 'bool' output dtype and a CPU implementation supplied.
var equal = binaryKernelFunc({
    opSnippet: EQUAL,
    packedOpSnippet: PACKED_EQUAL,
    dtype: 'bool',
    cpuKernelImpl: equalImplCPU
});
// Kernel config tying tf.Equal to equal on the 'webgl' backend.
var equalConfig = {
    kernelName: tf.Equal,
    backendName: 'webgl',
    kernelFunc: equal
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// GLSL snippet approximating erf(x) with a polynomial in t = 1 / (1 + p*x);
// the coefficients are taken from tf.backend_util (ERF_P, ERF_A1..ERF_A5).
var ERF = "\n // Error function is calculated approximately with elementary function." +
    "\n // See \"Handbook of Mathematical Functions with Formulas," +
    "\n // Graphs, and Mathematical Tables\", Abramowitz and Stegun." +
    "\n float p = " + tf.backend_util.ERF_P + ";" +
    "\n float a1 = " + tf.backend_util.ERF_A1 + ";" +
    "\n float a2 = " + tf.backend_util.ERF_A2 + ";" +
    "\n float a3 = " + tf.backend_util.ERF_A3 + ";" +
    "\n float a4 = " + tf.backend_util.ERF_A4 + ";" +
    "\n float a5 = " + tf.backend_util.ERF_A5 + ";" +
    "\n\n float sign = sign(x);" +
    "\n x = abs(x);" +
    "\n float t = 1.0 / (1.0 + p * x);" +
    "\n return sign * (1.0 - (((((a5*t + a4)*t) + a3)*t + a2)*t + a1)*t*exp(-x*x));" +
    "\n";
// Unary kernel built from the scalar snippet only (no packed variant).
var erf = unaryKernelFunc({
    opSnippet: ERF
});
// Kernel config tying tf.Erf to erf on the 'webgl' backend.
var erfConfig = {
    kernelName: tf.Erf,
    backendName: 'webgl',
    kernelFunc: erf
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// GLSL snippet for exp(x); used for both the unpacked and packed programs.
var EXP = "return exp(x);";
// Unary kernel with a CPU implementation supplied.
var exp = unaryKernelFunc({
    opSnippet: EXP,
    packedOpSnippet: EXP,
    cpuKernelImpl: expImplCPU
});
// Kernel config tying tf.Exp to exp on the 'webgl' backend.
var expConfig = {
    kernelName: tf.Exp,
    backendName: 'webgl',
    kernelFunc: exp
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the License);
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an AS IS BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * WebGL kernel function for ExpandDims: inserts a size-1 dimension into the
 * input's shape at position `dim` and reshapes the input to that shape.
 *
 * A negative `dim` counts from the end and is asserted to be no smaller than
 * -(rank + 1) before being converted to `rank + dim + 1`.
 *
 * @param args Object carrying `inputs` ({input}), `attrs` ({dim}) and
 *     `backend`.
 * @returns The result of `reshape(...)` with the expanded shape.
 */
function expandDims(args) {
    var inputs = args.inputs;
    var attrs = args.attrs;
    var backend = args.backend;
    var dim = attrs.dim;
    var input = inputs.input;
    var inputRank = input.shape.length;
    var expandedShape = input.shape.slice();
    var insertAt = dim;
    if (dim < 0) {
        // Negative value is counted from the tail of rank.
        tf.util.assert(-(inputRank + 1) <= dim, function () {
            return "Axis must be in the interval [" + -(inputRank + 1) + ", " + inputRank + "]";
        });
        insertAt = inputRank + dim + 1;
    }
    expandedShape.splice(insertAt, 0, 1);
    return reshape({
        inputs: { x: input },
        backend: backend,
        attrs: { shape: expandedShape }
    });
}
|
|
var expandDimsConfig = {
|
|
kernelName: tf.ExpandDims,
|
|
backendName: 'webgl',
|
|
kernelFunc: expandDims,
|
|
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// GLSL function body used for both the unpacked and the packed shader:
// element-wise exp(x) - 1.
var EXPM1 = "return exp(x) - 1.0;";
// Expm1 kernel function produced by the shared unary-kernel factory.
// expm1ImplCPU is handed to the factory as its `cpuKernelImpl` option.
var expm1 = unaryKernelFunc({
    opSnippet: EXPM1,
    packedOpSnippet: EXPM1,
    cpuKernelImpl: expm1ImplCPU
});
// Registration record binding the Expm1 kernel name to the 'webgl' backend.
var expm1Config = { kernelName: tf.Expm1, backendName: 'webgl', kernelFunc: expm1 };
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
var FFTProgram = /** @class */ (function () {
    /**
     * Shader program that computes one component (real or imaginary) of a
     * discrete Fourier transform along the second dimension of a 2D input.
     *
     * @param component Which output component to produce: 'real' or 'imag'.
     * @param inputShape Two-element shape; inputShape[1] is the transform length.
     * @param inverse When truthy, the exponent sign is flipped and each term
     *     is divided by the transform length.
     * @throws Error if `component` is neither 'real' nor 'imag'.
     */
    function FFTProgram(component, inputShape, inverse) {
        this.variableNames = ['real', 'imag'];
        var transformLength = inputShape[1];
        this.outputShape = inputShape;
        var exponentMultiplierSnippet = (inverse ? "2.0 * " : "-2.0 * ") + Math.PI;
        var resultDenominator = '1.0';
        if (inverse) {
            resultDenominator = transformLength + ".0";
        }
        // Pick the complex-multiplication component emitted by unaryOpComplex.
        var opString;
        switch (component) {
            case 'real':
                opString = 'return real * expR - imag * expI;';
                break;
            case 'imag':
                opString = 'return real * expI + imag * expR;';
                break;
            default:
                throw new Error("FFT component must be either \"real\" or \"imag\", got " + component + ".");
        }
        this.userCode =
            "\n const float exponentMultiplier = " + exponentMultiplierSnippet +
            ";\n\n float unaryOpComplex(float real, float expR, float imag, float expI) {\n " + opString +
            "\n }\n\n float mulMatDFT(int batch, int index) {\n float indexRatio = float(index) / float(" + transformLength +
            ");\n float exponentMultiplierTimesIndexRatio =\n exponentMultiplier * indexRatio;\n\n float result = 0.0;\n\n for (int i = 0; i < " + transformLength +
            "; i++) {\n // x = (-2|2 * PI / N) * index * i;\n float x = exponentMultiplierTimesIndexRatio * float(i);\n float expR = cos(x);\n float expI = sin(x);\n float real = getReal(batch, i);\n float imag = getImag(batch, i);\n\n result +=\n unaryOpComplex(real, expR, imag, expI) / " + resultDenominator +
            ";\n }\n\n return result;\n }\n\n void main() {\n ivec2 coords = getOutputCoords();\n setOutput(mulMatDFT(coords[0], coords[1]));\n }\n ";
    }
    return FFTProgram;
}());
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Shared implementation behind the FFT kernel: flattens the input to
 * [batch, innerDim], runs one FFTProgram for the real part and one for the
 * imaginary part, recombines them with complex() and reshapes the result
 * back to the input's shape.
 *
 * @param x Tensor info whose texData entry carries `complexTensorInfos`.
 * @param inverse Forwarded to FFTProgram to select the inverse transform.
 * @param backend Backend used for texture lookup, program execution and
 *     disposal of intermediates.
 * @returns The reshaped complex output tensor info.
 */
function fftImpl(x, inverse, backend) {
    var xData = backend.texData.get(x.dataId);
    // Collapse all outer dimensions to a single batch dimension.
    var innerDimensionSize = x.shape[x.shape.length - 1];
    var batch = tf.util.sizeFromShape(x.shape) / innerDimensionSize;
    var input2D = reshape({
        inputs: { x: x },
        backend: backend,
        attrs: { shape: [batch, innerDimensionSize] }
    });
    var xShape = input2D.shape;
    var realProgram = new FFTProgram('real', xShape, inverse);
    var imagProgram = new FFTProgram('imag', xShape, inverse);
    // Both programs read the same two inputs: the real and imaginary
    // components of x, viewed with the flattened 2D shape.
    var componentInfos = [xData.complexTensorInfos.real, xData.complexTensorInfos.imag];
    var inputs = componentInfos.map(function (info) {
        return { dataId: info.dataId, dtype: info.dtype, shape: xShape };
    });
    var realPart = backend.runWebGLProgram(realProgram, inputs, 'float32');
    var imagPart = backend.runWebGLProgram(imagProgram, inputs, 'float32');
    var complexOutput = complex({
        inputs: { real: realPart, imag: imagPart },
        backend: backend
    });
    backend.disposeIntermediateTensorInfo(realPart);
    backend.disposeIntermediateTensorInfo(imagPart);
    var complexOutputReshaped = reshape({
        inputs: { x: complexOutput },
        backend: backend,
        attrs: { shape: x.shape }
    });
    backend.disposeIntermediateTensorInfo(input2D);
    backend.disposeIntermediateTensorInfo(complexOutput);
    return complexOutputReshaped;
}
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * FFT kernel: forward transform of `inputs.input`, delegated to fftImpl.
 *
 * @param args Kernel arguments. Reads `inputs.input` and `backend`.
 * @returns The tensor info produced by fftImpl.
 */
function fft(args) {
    var backend = args.backend;
    var input = args.inputs.input;
    return fftImpl(input, false /* inverse */, backend);
}
// Registration record binding the FFT kernel name to the 'webgl' backend.
var fftConfig = { kernelName: tf.FFT, backendName: 'webgl', kernelFunc: fft };
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
var FillProgram = /** @class */ (function () {
    /**
     * Shader program that writes a single uniform float to every output
     * element.
     *
     * @param shape Output shape of the program.
     * @param value Not read by the constructor; the fill value is supplied
     *     through getCustomSetupFunc().
     */
    function FillProgram(shape, value) {
        this.outputShape = shape;
        this.variableNames = ['x'];
        this.userCode = "\n uniform float value;\n void main() {\n // Input can be obtained from uniform value.\n setOutput(value);\n }\n ";
    }
    /**
     * Builds the callback that uploads `value` into the shader's `value`
     * uniform. The uniform location is looked up on first use and cached on
     * the program instance.
     *
     * @param value Number to upload.
     * @returns Function (gpgpu, webGLProgram) that performs the upload.
     */
    FillProgram.prototype.getCustomSetupFunc = function (value) {
        var program = this;
        return function (gpgpu, webGLProgram) {
            var location = program.valueLoc;
            if (location == null) {
                location = gpgpu.getUniformLocationNoThrow(webGLProgram, 'value');
                program.valueLoc = location;
            }
            gpgpu.gl.uniform1f(location, value);
        };
    };
    return FillProgram;
}());
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Fill kernel: creates a tensor of `attrs.shape` in which every element is
 * `attrs.value`.
 *
 * @param args Kernel arguments. Reads `backend`, `attrs.shape`,
 *     `attrs.value` and `attrs.dtype` (inferred from the value when falsy).
 * @returns The tensor info for the filled tensor.
 */
function fill(args) {
    var backend = args.backend;
    var attrs = args.attrs;
    var shape = attrs.shape;
    var value = attrs.value;
    var dtype = attrs.dtype || tf.util.inferDtype(value);
    if (dtype !== 'string') {
        // Non-string values are written by a shader from a uniform.
        var program = new FillProgram(shape, value);
        var customSetup = program.getCustomSetupFunc(value);
        return backend.runWebGLProgram(program, [], dtype, customSetup);
    }
    // String type should be handled in CPU memory.
    var values = tf.util.getArrayFromDType(dtype, tf.util.sizeFromShape(shape));
    values.fill(value);
    return backend.makeTensorInfo(shape, dtype, values);
}
// Registration record binding the Fill kernel name to the 'webgl' backend.
var fillConfig = { kernelName: tf.Fill, backendName: 'webgl', kernelFunc: fill };
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
var FlipLeftRightProgram = /** @class */ (function () {
    /**
     * Shader program that mirrors a 4D image tensor along its third
     * dimension (index 2, read here as the image width).
     *
     * Output column x reads input column (width - 1 - x), so column 0 maps
     * to the last column and vice versa. The previous `width - x` mapping was
     * off by one: x = 0 fell outside the image and fell back to the
     * unflipped pixel, and every other column was shifted by one.
     *
     * @param imageShape Four-element shape; imageShape[2] is the width.
     */
    function FlipLeftRightProgram(imageShape) {
        this.variableNames = ['Image'];
        this.outputShape = [];
        var imageWidth = imageShape[2];
        this.outputShape = imageShape;
        // The bounds check is kept as a guard; with the corrected mapping it
        // holds for every valid output column.
        this.userCode = "\n void main() {\n ivec4 coords = getOutputCoords();\n int x = coords[2];\n\n int coordX = " + imageWidth + " - x - 1;\n float outputValue;\n if(coordX >= 0 && coordX < " + imageWidth + ") {\n outputValue = getImage(coords[0], coords[1], coordX, coords[3]);\n } else {\n outputValue = getImage(coords[0], coords[1], coords[2], coords[3]);\n }\n setOutput(outputValue);\n }\n ";
    }
    return FlipLeftRightProgram;
}());
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// Registration record binding the FlipLeftRight kernel name to the 'webgl'
// backend. The kernel runs a FlipLeftRightProgram over `inputs.image` and
// returns the program's output with the image's dtype.
var flipLeftRightConfig = {
    kernelName: tf.FlipLeftRight,
    backendName: 'webgl',
    kernelFunc: function (kernelArgs) {
        var image = kernelArgs.inputs.image;
        var webglBackend = kernelArgs.backend;
        var flipProgram = new FlipLeftRightProgram(image.shape);
        return webglBackend.runWebGLProgram(flipProgram, [image], image.dtype);
    }
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// GLSL function body used for both the unpacked and the packed shader:
// element-wise floor(x).
var FLOOR = "return floor(x);";
// Floor kernel function produced by the shared unary-kernel factory.
// floorImplCPU is handed to the factory as its `cpuKernelImpl` option.
var floor = unaryKernelFunc({
    opSnippet: FLOOR,
    packedOpSnippet: FLOOR,
    cpuKernelImpl: floorImplCPU
});
// Registration record binding the Floor kernel name to the 'webgl' backend.
var floorConfig = { kernelName: tf.Floor, backendName: 'webgl', kernelFunc: floor };
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// We use native integer division to deal with floating point imprecision. Since
// we implement floor division and glsl implements truncated division, we
// correct for this by subtracting 1 from result when the result is negative and
// there is a remainder.
//
// Unpacked snippet: rounds both operands to int and divides through idiv(),
// returning NAN when the divisor rounds to zero.
var INT_DIV = "\n float s = sign(a) * sign(b);" +
    "\n int ia = round(a);" +
    "\n int ib = round(b);" +
    "\n if (ib != 0) {" +
    "\n // Windows (D3D) wants guaranteed non-zero int division at compile-time." +
    "\n return float(idiv(ia, ib, s));" +
    "\n } else {" +
    "\n return NAN;" +
    "\n }\n";
// Packed snippet: the same division applied per lane of a vec4. A lane whose
// divisor rounds to zero keeps the initial value 0.
var INT_DIV_PACKED = "\n ivec4 ia = round(a);" +
    "\n ivec4 ib = round(b);" +
    "\n bvec4 cond = notEqual(ib, ivec4(0));" +
    "\n ivec4 result = ivec4(0);" +
    "\n vec4 s = sign(a) * sign(b);" +
    "\n" +
    "\n // Windows (D3D) wants guaranteed non-zero int division at compile-time." +
    "\n if (cond[0]) {" +
    "\n result[0] = idiv(ia[0], ib[0], s[0]);" +
    "\n }" +
    "\n if (cond[1]) {" +
    "\n result[1] = idiv(ia[1], ib[1], s[1]);" +
    "\n }" +
    "\n if (cond[2]) {" +
    "\n result[2] = idiv(ia[2], ib[2], s[2]);" +
    "\n }" +
    "\n if (cond[3]) {" +
    "\n result[3] = idiv(ia[3], ib[3], s[3]);" +
    "\n }" +
    "\n return vec4(result);\n";
// FloorDiv kernel function produced by the shared binary-kernel factory,
// with the dtype option set to int32.
var floorDiv = binaryKernelFunc({
    opSnippet: INT_DIV,
    packedOpSnippet: INT_DIV_PACKED,
    dtype: 'int32'
});
// Registration record binding the FloorDiv kernel name to the 'webgl' backend.
var floorDivConfig = { kernelName: tf.FloorDiv, backendName: 'webgl', kernelFunc: floorDiv };
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
var FromPixelsProgram = /** @class */ (function () {
    /**
     * Unpacked shader program that samples texture `A` and writes one colour
     * channel per output element, scaled to the 0..255 range.
     *
     * @param outputShape Three-element output shape; the first two entries
     *     are used as height and width when normalising texture coordinates.
     */
    function FromPixelsProgram(outputShape) {
        this.variableNames = ['A'];
        var glsl = getGlslDifferences();
        var height = outputShape[0];
        var width = outputShape[1];
        this.outputShape = outputShape;
        // The third output coordinate selects the channel (0..3 -> r, g, b, a).
        this.userCode =
            "\n void main() {\n ivec3 coords = getOutputCoords();\n int texR = coords[0];\n int texC = coords[1];\n int depth = coords[2];\n vec2 uv = (vec2(texC, texR) + halfCR) / vec2(" + width + ".0, " + height +
            ".0);\n\n vec4 values = " + glsl.texture2D +
            "(A, uv);\n float value;\n if (depth == 0) {\n value = values.r;\n } else if (depth == 1) {\n value = values.g;\n } else if (depth == 2) {\n value = values.b;\n } else if (depth == 3) {\n value = values.a;\n }\n\n setOutput(floor(value * 255.0 + 0.5));\n }\n ";
    }
    return FromPixelsProgram;
}());
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
var FromPixelsPackedProgram = /** @class */ (function () {
    /**
     * Packed-output variant of the from-pixels shader: each invocation
     * samples texture `A` for a 2x2 block (two columns by two channels) and
     * writes the four values, scaled to 0..255, into one output vec4.
     *
     * @param outputShape Three-element output shape; the first two entries
     *     are used as height and width when normalising texture coordinates.
     */
    function FromPixelsPackedProgram(outputShape) {
        this.variableNames = ['A'];
        // The input is read unpacked; only the output is packed.
        this.packedInputs = false;
        this.packedOutput = true;
        var glsl = getGlslDifferences();
        var height = outputShape[0];
        var width = outputShape[1];
        this.outputShape = outputShape;
        this.userCode =
            "\n void main() {\n ivec3 coords = getOutputCoords();\n int texR = coords[0];\n int texC = coords[1];\n int depth = coords[2];\n\n vec4 result = vec4(0.);\n\n for(int row=0; row<=1; row++) {\n for(int col=0; col<=1; col++) {\n texC = coords[1] + row;\n depth = coords[2] + col;\n\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(" + width + ".0, " + height +
            ".0);\n vec4 values = " + glsl.texture2D +
            "(A, uv);\n float value;\n if (depth == 0) {\n value = values.r;\n } else if (depth == 1) {\n value = values.g;\n } else if (depth == 2) {\n value = values.b;\n } else if (depth == 3) {\n value = values.a;\n }\n\n result[row * 2 + col] = floor(value * 255.0 + 0.5);\n }\n }\n\n " + glsl.output +
            " = result;\n }\n ";
    }
    return FromPixelsPackedProgram;
}());
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// Registration record binding the FromPixels kernel name to the 'webgl'
// backend. `fromPixels` is a hoisted function declaration defined below.
var fromPixelsConfig = {
    kernelName: tf.FromPixels,
    backendName: 'webgl',
    kernelFunc: fromPixels
};
// 2D canvas context created on first use and reused by later calls to draw
// image/video sources before upload.
var fromPixels2DContext;
/**
 * FromPixels kernel: uploads a pixel source to a texture and runs a shader
 * that produces an int32 tensor of shape [height, width, numChannels].
 *
 * @param args Kernel arguments. Reads `inputs.pixels`, `attrs.numChannels`
 *     and `backend`.
 * @returns The tensor info produced by the from-pixels program.
 */
function fromPixels(args) {
    var backend = args.backend;
    var pixels = args.inputs.pixels;
    var numChannels = args.attrs.numChannels;
    // The typeof guards keep this working where the DOM classes are absent.
    var isVideo = typeof (HTMLVideoElement) !== 'undefined' &&
        pixels instanceof HTMLVideoElement;
    var isImage = typeof (HTMLImageElement) !== 'undefined' &&
        pixels instanceof HTMLImageElement;
    var width;
    var height;
    if (isVideo) {
        width = pixels.videoWidth;
        height = pixels.videoHeight;
    }
    else {
        width = pixels.width;
        height = pixels.height;
    }
    var texShape = [height, width];
    var outShape = [height, width, numChannels];
    if (isImage || isVideo) {
        // Draw image/video sources onto the shared canvas and upload the
        // canvas instead of the element itself.
        if (fromPixels2DContext == null) {
            fromPixels2DContext = document.createElement('canvas').getContext('2d');
        }
        var canvas = fromPixels2DContext.canvas;
        canvas.width = width;
        canvas.height = height;
        fromPixels2DContext.drawImage(pixels, 0, 0, width, height);
        pixels = fromPixels2DContext.canvas;
    }
    var tempPixelHandle = backend.makeTensorInfo(texShape, 'int32');
    // This is a byte texture with pixels.
    backend.texData.get(tempPixelHandle.dataId).usage = TextureUsage.PIXELS;
    backend.gpgpu.uploadPixelDataToTexture(backend.getTexture(tempPixelHandle.dataId), pixels);
    var program;
    if (tf.env().getBool('WEBGL_PACK')) {
        program = new FromPixelsPackedProgram(outShape);
    }
    else {
        program = new FromPixelsProgram(outShape);
    }
    var res = backend.runWebGLProgram(program, [tempPixelHandle], 'int32');
    // The temporary upload texture is no longer needed once the program ran.
    backend.disposeData(tempPixelHandle.dataId);
    return res;
}
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * FusedConv2D kernel: 2D convolution with optional bias and activation
 * fused into the same pass. Dispatches to one of three implementations:
 * conv2dByMatMul for 1x1 filters with unit stride/dilation and SAME/VALID
 * padding, conv2dWithIm2Row when WEBGL_CONV_IM2COL is set and the batch
 * size is 1, and Conv2DProgram otherwise.
 *
 * @param args Kernel arguments. Reads `inputs` (x, filter, bias,
 *     preluActivationWeights), `attrs` (strides, pad, dataFormat, dilations,
 *     dimRoundingMode, activation, leakyreluAlpha) and `backend`.
 * @returns The convolution output reshaped to `convInfo.outShape`.
 */
function fusedConv2d(args) {
    var inputs = args.inputs, backend = args.backend, attrs = args.attrs;
    var x = inputs.x;
    var filter = inputs.filter;
    var bias = inputs.bias;
    var preluActivationWeights = inputs.preluActivationWeights;
    var activation = attrs.activation;
    var leakyreluAlpha = attrs.leakyreluAlpha;
    var $dataFormat = tf.backend_util.convertConv2DDataFormat(attrs.dataFormat);
    var convInfo = tf.backend_util.computeConv2DInfo(x.shape, filter.shape, attrs.strides, attrs.dilations, attrs.pad, attrs.dimRoundingMode, false /* depthwise */, $dataFormat);
    // Both helper-based paths take the same argument bag.
    var buildHelperArgs = function () {
        return {
            x: x,
            filter: filter,
            convInfo: convInfo,
            backend: backend,
            bias: bias,
            activation: activation,
            preluActivationWeights: preluActivationWeights,
            leakyreluAlpha: leakyreluAlpha
        };
    };
    var isPointwise = convInfo.filterHeight === 1 && convInfo.filterWidth === 1 &&
        convInfo.dilationHeight === 1 && convInfo.dilationWidth === 1 &&
        convInfo.strideHeight === 1 && convInfo.strideWidth === 1 &&
        (convInfo.padInfo.type === 'SAME' || convInfo.padInfo.type === 'VALID');
    var out;
    var intermediates = [];
    if (isPointwise) {
        out = conv2dByMatMul(buildHelperArgs());
    }
    else if (tf.env().getBool('WEBGL_CONV_IM2COL') && x.shape[0] === 1) {
        out = conv2dWithIm2Row(buildHelperArgs());
    }
    else {
        var usesLeakyrelu = activation === 'leakyrelu';
        var fusedActivation = activation ? mapActivationToShaderProgram(activation, false) : null;
        var program = new Conv2DProgram(convInfo, bias != null, fusedActivation, preluActivationWeights != null, usesLeakyrelu);
        // Optional inputs are appended in the order bias, prelu weights,
        // leakyrelu alpha.
        var programInputs = [x, filter];
        if (bias) {
            programInputs.push(bias);
        }
        if (preluActivationWeights) {
            programInputs.push(preluActivationWeights);
        }
        if (usesLeakyrelu) {
            // The alpha attribute is passed to the shader as a scalar tensor
            // that must be released afterwards.
            var $leakyreluAlpha = backend.makeTensorInfo([], 'float32', tf.util.createScalarValue(leakyreluAlpha, 'float32'));
            programInputs.push($leakyreluAlpha);
            intermediates.push($leakyreluAlpha);
        }
        out = backend.runWebGLProgram(program, programInputs, 'float32');
    }
    var outReshaped = reshape({
        inputs: { x: out },
        backend: backend,
        attrs: { shape: convInfo.outShape }
    });
    // The pre-reshape output is an intermediate as well.
    intermediates.push(out);
    for (var i = 0; i < intermediates.length; i++) {
        backend.disposeIntermediateTensorInfo(intermediates[i]);
    }
    return outReshaped;
}
// Registration record binding the FusedConv2D kernel name to the 'webgl' backend.
var fusedConv2DConfig = {
    kernelName: tf.FusedConv2D,
    backendName: 'webgl',
    kernelFunc: fusedConv2d
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * FusedDepthwiseConv2D kernel: depthwise 2D convolution with optional bias
 * and activation fused into the same shader pass.
 *
 * @param args Kernel arguments. Reads `inputs` (x, filter, bias,
 *     preluActivationWeights), `attrs` (strides, pad, dilations,
 *     dimRoundingMode, activation, leakyreluAlpha) and `backend`.
 * @returns The float32 tensor info produced by the depthwise program.
 */
function fusedDepthwiseConv2D(args) {
    var inputs = args.inputs, backend = args.backend, attrs = args.attrs;
    var x = inputs.x;
    var filter = inputs.filter;
    var bias = inputs.bias;
    var preluActivationWeights = inputs.preluActivationWeights;
    var strides = attrs.strides;
    var activation = attrs.activation;
    var leakyreluAlpha = attrs.leakyreluAlpha;
    var intermediates = [];
    // Missing dilations default to [1, 1].
    var $dilations = attrs.dilations == null ? [1, 1] : attrs.dilations;
    tf.util.assert(tf.backend_util.eitherStridesOrDilationsAreOne(strides, $dilations), function () {
        return 'Error in depthwiseConv2d: Either strides or dilations must be ' +
            ("1. Got strides " + strides + " and dilations '" + $dilations + "'");
    });
    var convInfo = tf.backend_util.computeConv2DInfo(x.shape, filter.shape, strides, $dilations, attrs.pad, attrs.dimRoundingMode, true /* depthwise */);
    // The packed program is chosen only when the flag is on, the horizontal
    // stride is at most 2 and the channel multiplier is 1.
    var shouldPackDepthwiseConv = tf.env().getBool('WEBGL_PACK_DEPTHWISECONV') &&
        convInfo.strideWidth <= 2 &&
        convInfo.outChannels / convInfo.inChannels === 1;
    var fusedActivation = null;
    if (activation) {
        fusedActivation = mapActivationToShaderProgram(activation, shouldPackDepthwiseConv);
    }
    var hasBias = bias != null;
    var hasPreluActivationWeights = preluActivationWeights != null;
    var hasLeakyreluAlpha = activation === 'leakyrelu';
    // Optional inputs are appended in the order bias, prelu weights,
    // leakyrelu alpha.
    var programInputs = [x, filter];
    if (hasBias) {
        programInputs.push(bias);
    }
    if (hasPreluActivationWeights) {
        programInputs.push(preluActivationWeights);
    }
    if (hasLeakyreluAlpha) {
        // The alpha attribute is passed to the shader as a scalar tensor
        // that must be released afterwards.
        var $leakyreluAlpha = backend.makeTensorInfo([], 'float32', tf.util.createScalarValue(leakyreluAlpha, 'float32'));
        programInputs.push($leakyreluAlpha);
        intermediates.push($leakyreluAlpha);
    }
    var program = shouldPackDepthwiseConv ?
        new DepthwiseConvPacked2DProgram(convInfo, hasBias, fusedActivation, hasPreluActivationWeights, hasLeakyreluAlpha) :
        new DepthwiseConv2DProgram(convInfo, hasBias, fusedActivation, hasPreluActivationWeights, hasLeakyreluAlpha);
    var result = backend.runWebGLProgram(program, programInputs, 'float32');
    for (var i = 0; i < intermediates.length; i++) {
        backend.disposeIntermediateTensorInfo(intermediates[i]);
    }
    return result;
}
// Registration record binding the FusedDepthwiseConv2D kernel name to the
// 'webgl' backend.
var fusedDepthwiseConv2DConfig = {
    kernelName: tf.FusedDepthwiseConv2D,
    backendName: 'webgl',
    kernelFunc: fusedDepthwiseConv2D
};
|
|
|
|
var GatherNDProgram = /** @class */ (function () {
    /**
     * Shader program for gatherND over flattened inputs: for each output row
     * it turns `sliceDim` index components into one flat row offset using
     * `strides`, then copies the matching element of `x`.
     *
     * @param sliceDim Number of index components read per output row.
     * @param strides Stride per index component, emitted as a GLSL constant.
     * @param shape Output shape of the program.
     */
    function GatherNDProgram(sliceDim, strides, shape) {
        this.sliceDim = sliceDim;
        this.strides = strides;
        this.variableNames = ['x', 'indices'];
        this.outputShape = shape;
        var stridesType = getCoordsDataType(strides.length);
        var dtype = getCoordsDataType(shape.length);
        // With a single index component `strides` is referenced unindexed in
        // the shader; otherwise it is indexed by the loop variable.
        var strideString = sliceDim > 1 ? 'strides[j]' : 'strides';
        this.userCode =
            "\n " + stridesType + " strides = " + stridesType + "(" + strides +
            ");\n void main() {\n " + dtype +
            " coords = getOutputCoords();\n int flattenIndex = 0;\n for (int j = 0; j < " + sliceDim +
            "; j++) {\n int index = round(getIndices(coords[0], j));\n flattenIndex += index * " + strideString +
            ";\n }\n setOutput(getX(flattenIndex, coords[1]));\n }\n ";
    }
    return GatherNDProgram;
}());
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * GatherNd kernel: gathers slices of `params` at the positions given by
 * `indices`.
 *
 * String tensors, and inputs the backend prefers to run on the CPU, are
 * handled by gatherNdImplCPU. Everything else is flattened to 2D and run
 * through GatherNDProgram.
 *
 * @param args Kernel arguments. Reads `inputs.params`, `inputs.indices`
 *     and `backend`.
 * @returns Tensor info with the shape computed by prepareAndValidate.
 */
function gatherNd(args) {
    var inputs = args.inputs, backend = args.backend;
    var params = inputs.params, indices = inputs.indices;
    var indicesShape = indices.shape;
    var sliceRank = indicesShape[indicesShape.length - 1];
    var paramsSize = tf.util.sizeFromShape(params.shape);
    var _a = tf.backend_util.prepareAndValidate(params, indices), resultShape = _a[0], numSlices = _a[1], sliceSize = _a[2], strides = _a[3];
    if (backend.shouldExecuteOnCPU([params, indices]) ||
        params.dtype === 'string') {
        // The CPU path reads the original tensors directly, so it returns
        // before any reshaped views are created. Creating them first and
        // returning here left them undisposed, unlike the GPU path below.
        var indicesData = backend.readSync(indices.dataId);
        var paramsBuf = backend.bufferSync(params);
        var outValue = gatherNdImplCPU(indicesData, paramsBuf, params.dtype, numSlices, sliceRank, sliceSize, strides, params.shape, paramsSize);
        return backend.makeTensorInfo(resultShape, params.dtype, outValue.values);
    }
    // Flatten indices to [numSlices, sliceRank] and params to
    // [paramsSize / sliceSize, sliceSize] for the shader.
    var flattenIndices = reshape({ inputs: { x: indices }, backend: backend, attrs: { shape: [numSlices, sliceRank] } });
    var flattenX = reshape({
        inputs: { x: params },
        backend: backend,
        attrs: { shape: [(paramsSize / sliceSize), sliceSize] }
    });
    var program = new GatherNDProgram(sliceRank, strides, [numSlices, sliceSize]);
    var res = backend.runWebGLProgram(program, [flattenX, flattenIndices], flattenX.dtype);
    var reshaped = reshape({ inputs: { x: res }, backend: backend, attrs: { shape: resultShape } });
    // Release every intermediate created on this path.
    backend.disposeIntermediateTensorInfo(flattenIndices);
    backend.disposeIntermediateTensorInfo(flattenX);
    backend.disposeIntermediateTensorInfo(res);
    return reshaped;
}
// Registration record binding the GatherNd kernel name to the 'webgl' backend.
var gatherNdConfig = {
    kernelName: tf.GatherNd,
    backendName: 'webgl',
    kernelFunc: gatherNd
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2017 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * WebGL program description for the gather op.
 *
 * Reads from sampler `A` at coordinates derived from the output coordinates,
 * with one coordinate replaced by a lookup into the `indices` sampler (see
 * `getSourceCoords$1`).
 *
 * @param aShape Shape of the `A` input; only its length drives the generated
 *     source coordinates.
 * @param outputShape Shape of the program output; its length is stored as
 *     `rank` and selects the GLSL coordinate type.
 */
var GatherProgram = /** @class */ (function () {
    function GatherProgram(aShape, outputShape) {
        this.variableNames = ['A', 'indices'];
        this.outputShape = outputShape;
        this.rank = outputShape.length;
        // GLSL type used to hold the output coordinates.
        var coordsType = getCoordsDataType(this.rank);
        // Comma-separated argument list passed to getA(...).
        var lookupArgs = getSourceCoords$1(aShape);
        this.userCode = "\n" +
            " void main() {\n" +
            " " + coordsType + " resRC = getOutputCoords();\n" +
            " setOutput(getA(" + lookupArgs + "));\n" +
            " }\n" +
            " ";
    }
    return GatherProgram;
}());
|
|
// The input and output are always flattened into rank 4 tensors.
|
|
/**
 * Builds the comma-separated GLSL coordinate list used to sample the gather
 * input.
 *
 * One entry is produced per dimension of `aShape`. Position 2 is replaced by
 * an integer lookup into the indices sampler; every other position forwards
 * the matching component of `resRC`.
 *
 * @param aShape Shape whose length determines how many coordinates are emitted.
 * @param axis Accepted for signature compatibility; not read by this function.
 * @returns The coordinates joined with ','.
 */
function getSourceCoords$1(aShape, axis) {
    var outputComponents = ['resRC.x', 'resRC.y', 'resRC.z', 'resRC.w'];
    var pieces = [];
    var dim = 0;
    while (dim < aShape.length) {
        pieces.push(dim === 2 ?
            'int(getIndices(resRC.x, resRC.z))' :
            "" + outputComponents[dim]);
        dim += 1;
    }
    return pieces.join(',');
}
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * GatherV2 kernel function for the webgl backend.
 *
 * Reshapes `x` to [batchSize, outerSize, dimSize, sliceSize] and `indices` to
 * [batchSize, indicesSize / batchSize], gathers either through the CPU
 * implementation or through a `GatherProgram`, and returns a tensor info with
 * the output shape reported by `collectGatherOpShapeInfo`.
 *
 * @param args Object with `inputs` ({x, indices}), `backend`, and `attrs`
 *     ({axis, batchDims}).
 * @returns Tensor info holding the gathered values.
 */
function gatherV2(args) {
    var backend = args.backend;
    var x = args.inputs.x;
    var indices = args.inputs.indices;
    var batchDims = args.attrs.batchDims;
    var gatherAxis = tf.util.parseAxisParam(args.attrs.axis, x.shape)[0];
    var info = tf.backend_util.segment_util.collectGatherOpShapeInfo(x, indices, gatherAxis, batchDims);
    var indicesPerBatch = tf.util.sizeFromShape(indices.shape) / info.batchSize;
    var x4D = reshape({
        inputs: { x: x },
        backend: backend,
        attrs: { shape: [info.batchSize, info.outerSize, info.dimSize, info.sliceSize] }
    });
    var indices2D = reshape({
        inputs: { x: indices },
        backend: backend,
        attrs: { shape: [info.batchSize, indicesPerBatch] }
    });
    // Intermediates released before returning, in creation order.
    var intermediates = [x4D, indices2D];
    var releaseIntermediates = function () {
        for (var i = 0; i < intermediates.length; i++) {
            backend.disposeIntermediateTensorInfo(intermediates[i]);
        }
    };
    var flatOutShape = [info.batchSize, info.outerSize, indicesPerBatch, info.sliceSize];
    // String tensors, and inputs the backend prefers on CPU, use the CPU path.
    if (backend.shouldExecuteOnCPU([x, indices]) || x.dtype === 'string') {
        var indicesBuf = backend.bufferSync(indices2D);
        var xBuf = backend.bufferSync(x4D);
        var gathered = gatherV2ImplCPU(xBuf, indicesBuf, flatOutShape);
        releaseIntermediates();
        return backend.makeTensorInfo(info.outputShape, gathered.dtype, gathered.values);
    }
    var program = new GatherProgram(x4D.shape, flatOutShape);
    var flatResult = backend.runWebGLProgram(program, [x4D, indices2D], x4D.dtype);
    intermediates.push(flatResult);
    var output = reshape({
        inputs: { x: flatResult },
        backend: backend,
        attrs: { shape: info.outputShape }
    });
    releaseIntermediates();
    return output;
}
|
|
/**
 * Kernel config pairing the `tf.GatherV2` kernel name with the `gatherV2`
 * kernel function for the 'webgl' backend.
 */
var gatherV2Config = {
    backendName: 'webgl',
    kernelName: tf.GatherV2,
    kernelFunc: gatherV2
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// GLSL snippet for the unpacked shader: a > b converted to a float.
var GREATER = "return float(a > b);";
// GLSL snippet for the packed (vec4) shader.
var GREATER_PACKED = "\n return vec4(greaterThan(a, b));\n";
/** Element-wise `a > b` kernel function; the result dtype is 'bool'. */
var greater = binaryKernelFunc({
    dtype: 'bool',
    opSnippet: GREATER,
    packedOpSnippet: GREATER_PACKED,
    cpuKernelImpl: greaterImplCPU
});
/** Kernel config pairing `tf.Greater` with `greater` for the 'webgl' backend. */
var greaterConfig = {
    backendName: 'webgl',
    kernelName: tf.Greater,
    kernelFunc: greater
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// GLSL snippet for the unpacked shader: a >= b converted to a float.
var GREATER_EQUAL = "return float(a >= b);";
// GLSL snippet for the packed (vec4) shader.
var GREATER_EQUAL_PACKED = "\n return vec4(greaterThanEqual(a, b));\n";
/** Element-wise `a >= b` kernel function; the result dtype is 'bool'. */
var greaterEqual = binaryKernelFunc({
    dtype: 'bool',
    opSnippet: GREATER_EQUAL,
    packedOpSnippet: GREATER_EQUAL_PACKED,
    cpuKernelImpl: greaterEqualImplCPU
});
/** Kernel config pairing `tf.GreaterEqual` with `greaterEqual` for the 'webgl' backend. */
var greaterEqualConfig = {
    backendName: 'webgl',
    kernelName: tf.GreaterEqual,
    kernelFunc: greaterEqual
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * IFFT kernel function for the webgl backend: delegates to `fftImpl` with the
 * inverse flag set.
 *
 * @param args Object with `inputs` ({input}) and `backend`.
 * @returns Whatever `fftImpl` returns for the inverse transform.
 */
function ifft(args) {
    var backend = args.backend;
    var inverse = true;
    return fftImpl(args.inputs.input, inverse, backend);
}
/** Kernel config pairing `tf.IFFT` with `ifft` for the 'webgl' backend. */
var ifftConfig = {
    backendName: 'webgl',
    kernelName: tf.IFFT,
    kernelFunc: ifft
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// GLSL snippet: 1.0 when x is neither NaN nor infinite, otherwise 0.0.
var IS_FINITE = "return float(!isnan(x) && !isinf(x));";
/**
 * Element-wise finiteness test kernel function; the result dtype is 'bool'.
 * Note: this binding shadows the global `isFinite` inside the enclosing scope.
 */
var isFinite = unaryKernelFunc({ dtype: 'bool', opSnippet: IS_FINITE });
/** Kernel config pairing `tf.IsFinite` with `isFinite` for the 'webgl' backend. */
var isFiniteConfig = {
    backendName: 'webgl',
    kernelName: tf.IsFinite,
    kernelFunc: isFinite
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// GLSL snippet: 1.0 when x is infinite, otherwise 0.0.
var IS_INF = "return float(isinf(x));";
/** Element-wise infinity test kernel function; the result dtype is 'bool'. */
var isInf = unaryKernelFunc({ dtype: 'bool', opSnippet: IS_INF });
/** Kernel config pairing `tf.IsInf` with `isInf` for the 'webgl' backend. */
var isInfConfig = {
    backendName: 'webgl',
    kernelName: tf.IsInf,
    kernelFunc: isInf
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// GLSL snippet: 1.0 when x is NaN, otherwise 0.0.
var IS_NAN = "return float(isnan(x));";
/**
 * Element-wise NaN test kernel function; the result dtype is 'bool'.
 * Note: this binding shadows the global `isNaN` inside the enclosing scope.
 */
var isNaN = unaryKernelFunc({ dtype: 'bool', opSnippet: IS_NAN });
/** Kernel config pairing `tf.IsNan` with `isNaN` for the 'webgl' backend. */
var isNaNConfig = {
    backendName: 'webgl',
    kernelName: tf.IsNan,
    kernelFunc: isNaN
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// GLSL snippet for the unpacked shader: a < b converted to a float.
var LESS = "return float(a < b);";
// GLSL snippet for the packed (vec4) shader.
var LESS_PACKED = "\n return vec4(lessThan(a, b));\n";
/** Element-wise `a < b` kernel function; the result dtype is 'bool'. */
var less = binaryKernelFunc({
    dtype: 'bool',
    opSnippet: LESS,
    packedOpSnippet: LESS_PACKED,
    cpuKernelImpl: lessImplCPU
});
/** Kernel config pairing `tf.Less` with `less` for the 'webgl' backend. */
var lessConfig = {
    backendName: 'webgl',
    kernelName: tf.Less,
    kernelFunc: less
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// GLSL snippet for the unpacked shader: a <= b converted to a float.
var LESS_EQUAL = "return float(a <= b);";
// GLSL snippet for the packed (vec4) shader.
var LESS_EQUAL_PACKED = "\n return vec4(lessThanEqual(a, b));\n";
/** Element-wise `a <= b` kernel function; the result dtype is 'bool'. */
var lessEqual = binaryKernelFunc({
    dtype: 'bool',
    opSnippet: LESS_EQUAL,
    packedOpSnippet: LESS_EQUAL_PACKED,
    cpuKernelImpl: lessEqualImplCPU
});
/** Kernel config pairing `tf.LessEqual` with `lessEqual` for the 'webgl' backend. */
var lessEqualConfig = {
    backendName: 'webgl',
    kernelName: tf.LessEqual,
    kernelFunc: lessEqual
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * LinSpace kernel function for the webgl backend.
 *
 * The values are produced by the CPU implementation and wrapped in a 1-D
 * 'float32' tensor info whose length is the number of values returned.
 *
 * @param args Object with `backend` and `attrs` ({start, stop, num}).
 * @returns A 'float32' tensor info of shape [outVals.length].
 */
function linSpace(args) {
    var attrs = args.attrs;
    // TODO: Use CPU implementation due to the precision problem in Safari.
    var values = linSpaceImplCPU(attrs.start, attrs.stop, attrs.num);
    return args.backend.makeTensorInfo([values.length], 'float32', values);
}
/** Kernel config pairing `tf.LinSpace` with `linSpace` for the 'webgl' backend. */
var linSpaceConfig = {
    backendName: 'webgl',
    kernelName: tf.LinSpace,
    kernelFunc: linSpace
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// GLSL snippet for the unpacked shader: NAN for negative x, log(x) otherwise.
var LOG = "if (x < 0.0) return NAN;\n return log(x);";
// GLSL snippet for the packed shader: each channel whose input is negative is
// overwritten with NAN after the vectorised log.
var LOG_PACKED = "\n vec4 result = log(x);\n" +
    " vec4 isNaN = vec4(lessThan(x, vec4(0.0)));\n" +
    " result.r = isNaN.r == 1.0 ? NAN : result.r;\n" +
    " result.g = isNaN.g == 1.0 ? NAN : result.g;\n" +
    " result.b = isNaN.b == 1.0 ? NAN : result.b;\n" +
    " result.a = isNaN.a == 1.0 ? NAN : result.a;\n" +
    "\n" +
    " return result;\n";
/** Element-wise natural logarithm kernel function. */
var log = unaryKernelFunc({
    opSnippet: LOG,
    packedOpSnippet: LOG_PACKED,
    cpuKernelImpl: logImplCPU
});
/** Kernel config pairing `tf.Log` with `log` for the 'webgl' backend. */
var logConfig = {
    backendName: 'webgl',
    kernelName: tf.Log,
    kernelFunc: log
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// GLSL snippet: log(1 + x), evaluated directly.
var LOG1P = "return log(1.0 + x);";
/** Element-wise `log(1 + x)` kernel function. */
var log1p = unaryKernelFunc({ opSnippet: LOG1P });
/** Kernel config pairing `tf.Log1p` with `log1p` for the 'webgl' backend. */
var log1pConfig = {
    backendName: 'webgl',
    kernelName: tf.Log1p,
    kernelFunc: log1p
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// GLSL snippet for the unpacked shader: an operand counts as true when >= 1.0.
var LOGICAL_AND = "return float(a >= 1.0 && b >= 1.0);";
// GLSL snippet for the packed shader: product of the two per-channel truth masks.
var LOGICAL_AND_PACKED = "\n return vec4(\n vec4(greaterThanEqual(a, vec4(1.0))) *\n vec4(greaterThanEqual(b, vec4(1.0))));\n";
/** Element-wise logical AND kernel function; the result dtype is 'bool'. */
var logicalAnd = binaryKernelFunc({
    dtype: 'bool',
    opSnippet: LOGICAL_AND,
    packedOpSnippet: LOGICAL_AND_PACKED
});
/** Kernel config pairing `tf.LogicalAnd` with `logicalAnd` for the 'webgl' backend. */
var logicalAndConfig = {
    backendName: 'webgl',
    kernelName: tf.LogicalAnd,
    kernelFunc: logicalAnd
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// GLSL snippet: 1.0 when x is not >= 1.0, otherwise 0.0.
var LOGICAL_NOT = "return float(!(x >= 1.0));";
/** Element-wise logical NOT kernel function (no explicit output dtype is passed). */
var logicalNot = unaryKernelFunc({ opSnippet: LOGICAL_NOT });
/** Kernel config pairing `tf.LogicalNot` with `logicalNot` for the 'webgl' backend. */
var logicalNotConfig = {
    backendName: 'webgl',
    kernelName: tf.LogicalNot,
    kernelFunc: logicalNot
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// GLSL snippet for the unpacked shader: an operand counts as true when >= 1.0.
var LOGICAL_OR = "return float(a >= 1.0 || b >= 1.0);";
// GLSL snippet for the packed shader: sum of the two truth masks clamped to 1.0.
var LOGICAL_OR_PACKED = "\n return min(\n vec4(greaterThanEqual(a, vec4(1.0))) +\n vec4(greaterThanEqual(b, vec4(1.0))),\n vec4(1.0));\n";
/** Element-wise logical OR kernel function; the result dtype is 'bool'. */
var logicalOr = binaryKernelFunc({
    dtype: 'bool',
    opSnippet: LOGICAL_OR,
    packedOpSnippet: LOGICAL_OR_PACKED
});
/** Kernel config pairing `tf.LogicalOr` with `logicalOr` for the 'webgl' backend. */
var logicalOrConfig = {
    backendName: 'webgl',
    kernelName: tf.LogicalOr,
    kernelFunc: logicalOr
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2017 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * WebGL program description for local response normalization (unpacked).
 *
 * For every output element the shader sums the squares of the input along the
 * last dimension within `radius` of the current index (clamped to
 * [0, xShape[3] - 1]) and multiplies the input by
 * (bias + alpha * sum) ^ (-beta).
 *
 * @param xShape Shape of the input; index 3 is read as the size of the
 *     normalized dimension. Also used as the output shape.
 * @param radius Half-width of the window along the last dimension.
 * @param bias Additive constant inside the power base.
 * @param alpha Scale applied to the sum of squares.
 * @param beta Exponent; 0.5 and 1.0 use cheaper dedicated expressions.
 */
var LRNProgram = /** @class */ (function () {
    function LRNProgram(xShape, radius, bias, alpha, beta) {
        this.variableNames = ['x'];
        this.outputShape = xShape;
        var rad = radius;
        var maxD = xShape[3] - 1;
        // optimize pow(bias + alpha * sum, -beta)
        // src: https://github.com/tensorflow/tensorflow/..
        // blob/26033a1644a9c4a5fbe3170ab2e864b6a4ccd4ca/..
        // tensorflow/core/kernels/mkl_lrn_op.cc#L320
        var powOperator;
        var basis = "float(" + bias + ") + float(" + alpha + ") * sum";
        if (beta === 0.5) {
            powOperator = "inversesqrt(" + basis + ")";
        }
        else if (beta === 1.0) {
            powOperator = "1.0/(" + basis + ")";
        }
        else {
            // Negate in JS instead of prefixing '-' in the shader text: a
            // negative beta would otherwise emit "--", which GLSL parses as
            // the decrement operator. No trailing ';' here either, since the
            // statement terminator is added where the expression is used.
            powOperator = "exp(log(" + basis + ") * float(" + (-beta) + "))";
        }
        this.userCode = "\n" +
            " void main() {\n" +
            " ivec4 coords = getOutputCoords();\n" +
            " int b = coords[0];\n" +
            " int r = coords[1];\n" +
            " int c = coords[2];\n" +
            " int d = coords[3];\n" +
            " float x = getX(b, r, c, d);\n" +
            " float sum = 0.0;\n" +
            " for (int j = -" + rad + "; j <= " + rad + "; j++) {\n" +
            " int idx = d + j;\n" +
            " if (idx >= 0 && idx <= " + maxD + ") {\n" +
            " float z = getX(b, r, c, idx);\n" +
            " sum += z * z;\n" +
            " }\n" +
            " }\n" +
            " float val = x * " + powOperator + ";\n" +
            " setOutput(val);\n" +
            " }\n" +
            " ";
    }
    return LRNProgram;
}());
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * WebGL program description for local response normalization using packed
 * inputs and outputs (`packedInputs` and `packedOutput` are both set).
 *
 * The shader accumulates, per packed texel channel, the sum of squared inputs
 * along the last dimension within `radius` of the current index and
 * multiplies the input by (bias + alpha * sum) ^ (-beta).
 *
 * @param xShape Shape of the input; indices 2 and 3 are read for bounds
 *     checks. Also used as the output shape.
 * @param radius Half-width of the window along the last dimension.
 * @param bias Additive constant inside the power base.
 * @param alpha Scale applied to the sum of squares.
 * @param beta Exponent; 0.5 and 1.0 use cheaper dedicated expressions.
 */
var LRNPackedProgram = /** @class */ (function () {
    function LRNPackedProgram(xShape, radius, bias, alpha, beta) {
        this.variableNames = ['x'];
        this.packedInputs = true;
        this.packedOutput = true;
        this.outputShape = xShape;
        var rad = radius;
        var maxD = xShape[3] - 1;
        // optimize pow(bias + alpha * sum, -beta)
        // src: https://github.com/tensorflow/tensorflow/..
        // blob/26033a1644a9c4a5fbe3170ab2e864b6a4ccd4ca/..
        // tensorflow/core/kernels/mkl_lrn_op.cc#L320
        var powOperator;
        var basis = "float(" + bias + ") + float(" + alpha + ") * sum";
        if (beta === 0.5) {
            powOperator = "inversesqrt(" + basis + ")";
        }
        else if (beta === 1.0) {
            powOperator = "1.0/(" + basis + ")";
        }
        else {
            // Negate in JS instead of prefixing '-' in the shader text: a
            // negative beta would otherwise emit "--", which GLSL parses as
            // the decrement operator. No trailing ';' here either, since the
            // statement terminator is added where the expression is used.
            powOperator = "exp(log(" + basis + ") * float(" + (-beta) + "))";
        }
        this.userCode = "\n" +
            " void main() {\n" +
            " ivec4 coords = getOutputCoords();\n" +
            " int b = coords.x;\n" +
            " int r = coords.y;\n" +
            " int c = coords.z;\n" +
            " int d = coords.w;\n" +
            "\n" +
            " bool hasNextCol = d < " + this.outputShape[3] + ";\n" +
            " bool hasNextRow = c < " + this.outputShape[2] + ";\n" +
            "\n" +
            " vec4 sum = vec4(0.);\n" +
            " vec4 xFragAtOutputCoords = getX(b, r, c, d);\n" +
            "\n" +
            " vec4 xAtOutputCoords = vec4(\n" +
            " getChannel(xFragAtOutputCoords, vec2(c, d)),\n" +
            " hasNextCol ?\n" +
            " getChannel(xFragAtOutputCoords, vec2(c, d + 1)) : 0.0,\n" +
            " hasNextRow ?\n" +
            " getChannel(xFragAtOutputCoords , vec2(c + 1, d)) : 0.0,\n" +
            " (hasNextRow && hasNextCol) ?\n" +
            " getChannel(xFragAtOutputCoords, vec2(c + 1, d + 1)) : 0.0\n" +
            " );\n" +
            "\n" +
            " int firstChannel = d - " + rad + ";\n" +
            " vec2 cache = vec2(0.);\n" +
            " if(firstChannel >= 0){\n" +
            " vec4 firstChannelFrag = getX(b, r, c, firstChannel);\n" +
            " cache.x = getChannel(firstChannelFrag, vec2(c, firstChannel));\n" +
            " if(hasNextRow){\n" +
            " cache.y = getChannel(firstChannelFrag, vec2(c + 1, firstChannel));\n" +
            " }\n" +
            " }\n" +
            "\n" +
            " ivec2 depth = ivec2(d, d + 1);\n" +
            " for (int j = - " + rad + "; j <= " + rad + "; j++) {\n" +
            " ivec2 idx = depth + j;\n" +
            " bvec2 aboveLowerBound = greaterThanEqual(idx, ivec2(0));\n" +
            " bvec2 belowUpperBound = lessThanEqual(idx, ivec2(" + maxD + "));\n" +
            "\n" +
            " bool depthInRange = aboveLowerBound.x && belowUpperBound.x;\n" +
            " bool depthPlusOneInRange = aboveLowerBound.y && belowUpperBound.y;\n" +
            "\n" +
            " if(depthInRange || depthPlusOneInRange){\n" +
            " vec4 z = vec4(0.);\n" +
            " vec4 xFragAtCurrentDepth;\n" +
            " z.xz = cache.xy;\n" +
            " if(depthPlusOneInRange && hasNextCol){\n" +
            " xFragAtCurrentDepth = idx.y != d ?\n" +
            " getX(b, r, c, idx.y) : xFragAtOutputCoords;\n" +
            " z.y = getChannel(xFragAtCurrentDepth, vec2(c, idx.y));\n" +
            " if(hasNextRow){\n" +
            " z.w = getChannel(xFragAtCurrentDepth, vec2(c + 1, idx.y));\n" +
            " }\n" +
            " }\n" +
            " cache.xy = z.yw;\n" +
            " sum += z * z;\n" +
            " }\n" +
            " }\n" +
            " vec4 result = xAtOutputCoords * " + powOperator + ";\n" +
            " setOutput(result);\n" +
            " }\n" +
            " ";
    }
    return LRNPackedProgram;
}());
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * LRN kernel function for the webgl backend.
 *
 * Builds either the packed or the unpacked LRN program, depending on the
 * 'WEBGL_PACK_NORMALIZATION' environment flag, and runs it on `x`.
 *
 * @param args Object with `inputs` ({x}), `backend`, and `attrs`
 *     ({depthRadius, bias, alpha, beta}).
 * @returns The result of `backend.runWebGLProgram`, with the dtype of `x`.
 */
var lrn = function (args) {
    var backend = args.backend;
    var x = args.inputs.x;
    var radius = args.attrs.depthRadius;
    var bias = args.attrs.bias;
    var alpha = args.attrs.alpha;
    var beta = args.attrs.beta;
    var program;
    if (tf.env().getBool('WEBGL_PACK_NORMALIZATION')) {
        program = new LRNPackedProgram(x.shape, radius, bias, alpha, beta);
    }
    else {
        program = new LRNProgram(x.shape, radius, bias, alpha, beta);
    }
    return backend.runWebGLProgram(program, [x], x.dtype);
};
/** Kernel config pairing `tf.LRN` with `lrn` for the 'webgl' backend. */
// tslint:disable-next-line: variable-name
var LRNConfig = {
    backendName: 'webgl',
    kernelName: tf.LRN,
    kernelFunc: lrn
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * WebGL program description for the gradient of local response normalization.
 *
 * The shader reads three samplers (`inputImage`, `outputImage`, `dy`). For
 * each output element it loops over every depth `d`, recomputes the
 * normalization term over the window [d - depthRadius, d + depthRadius]
 * clamped to [0, depth), and accumulates the contribution of `dy` at `d` to
 * the current channel.
 *
 * @param inputShape Shape of the input; index 3 is stored as `depth`. Also
 *     used as the output shape.
 * @param depthRadius Half-width of the normalization window.
 * @param bias Additive constant of the normalization term.
 * @param alpha Scale applied to the sum of squares.
 * @param beta Exponent of the normalization term.
 */
var LRNGradProgram = /** @class */ (function () {
    function LRNGradProgram(inputShape, depthRadius, bias, alpha, beta) {
        this.variableNames = ['inputImage', 'outputImage', 'dy'];
        this.outputShape = inputShape;
        this.depth = inputShape[3];
        this.depthRadius = depthRadius;
        this.bias = bias;
        this.alpha = alpha;
        this.beta = beta;
        // Every numeric attribute is wrapped in float(...) before it takes
        // part in float arithmetic: JS prints integral values without a
        // decimal point (1.0 -> "1"), and GLSL ES performs no implicit
        // int -> float conversion.
        this.userCode = "\n" +
            " void main() {\n" +
            " ivec4 coords = getOutputCoords();\n" +
            " int b = coords[0];\n" +
            " int r = coords[1];\n" +
            " int c = coords[2];\n" +
            "\n" +
            " float result = 0.0;\n" +
            " for (int d = 0; d < " + this.depth + "; ++d) {\n" +
            " int depthBegin = int(max(0.0, float(d - " + depthRadius + ")));\n" +
            " int depthEnd = int(min(float(" + this.depth + "),\n" +
            " float(d + " + depthRadius + " + 1)));\n" +
            "\n" +
            " const int MIN_DEPTH_BEGIN = 0;\n" +
            " const int MAX_DEPTH_END = " + this.depth + ";\n" +
            "\n" +
            " float norm = 0.0;\n" +
            " for (int k = MIN_DEPTH_BEGIN; k < MAX_DEPTH_END; ++k) {\n" +
            " if (k < depthBegin){\n" +
            " continue;\n" +
            " }\n" +
            " else if (k >= depthBegin && k < depthEnd) {\n" +
            " norm += getInputImage(b, r, c, k) * getInputImage(b, r, c, k);\n" +
            " }\n" +
            " else {\n" +
            " break;\n" +
            " }\n" +
            " }\n" +
            "\n" +
            " norm = float(" + alpha + ") * norm + float(" + bias + ");\n" +
            "\n" +
            " for(int k = MIN_DEPTH_BEGIN; k < MAX_DEPTH_END; ++k){\n" +
            " if (k < depthBegin){\n" +
            " continue;\n" +
            " }\n" +
            " else if (k >= depthBegin && k < depthEnd){\n" +
            " float dyi = -2.0 * float(" + alpha + ")\n" +
            " * float(" + beta + ")\n" +
            " * getInputImage(b ,r ,c, k) * getOutputImage(b, r, c, d)\n" +
            " / norm;\n" +
            " if (k == d) {\n" +
            " dyi += pow(norm, -1.0 * float(" + beta + "));\n" +
            " }\n" +
            " if (k == coords[3]) {\n" +
            " dyi *= getDy(b, r, c, d);\n" +
            " result += dyi;\n" +
            " }\n" +
            " }\n" +
            " else {\n" +
            " break;\n" +
            " }\n" +
            " }\n" +
            " }\n" +
            " setOutput(result);\n" +
            " }\n" +
            " ";
    }
    return LRNGradProgram;
}());
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * LRNGrad kernel function for the webgl backend.
 *
 * Builds an `LRNGradProgram` from the shape of `x` and the LRN attributes and
 * runs it on the inputs [x, y, dy].
 *
 * @param args Object with `inputs` ({x, y, dy}), `backend`, and `attrs`
 *     ({depthRadius, bias, alpha, beta}).
 * @returns The result of `backend.runWebGLProgram`, with the dtype of `x`.
 */
var lrnGrad = function (args) {
    var backend = args.backend;
    var x = args.inputs.x;
    var y = args.inputs.y;
    var dy = args.inputs.dy;
    var lrnAttrs = args.attrs;
    var gradProgram = new LRNGradProgram(x.shape, lrnAttrs.depthRadius, lrnAttrs.bias, lrnAttrs.alpha, lrnAttrs.beta);
    return backend.runWebGLProgram(gradProgram, [x, y, dy], x.dtype);
};
/** Kernel config pairing `tf.LRNGrad` with `lrnGrad` for the 'webgl' backend. */
// tslint:disable-next-line: variable-name
var LRNGradConfig = {
    backendName: 'webgl',
    kernelName: tf.LRNGrad,
    kernelFunc: lrnGrad
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Max reduction helper for the webgl backend.
 *
 * Reshapes `x` to [xSize / inSize, inSize], where `inSize` is the element
 * count of `reduceShape`, reduces it with the 'max' reduction, and reshapes
 * the result to `outShape`. The two intermediate tensor infos are disposed
 * before returning.
 *
 * @param x Tensor info to reduce.
 * @param reduceShape Shape of the dimensions being reduced.
 * @param outShape Shape of the returned tensor info.
 * @param backend Backend used for reshape, reduce and disposal.
 * @returns The reduced tensor info with shape `outShape`.
 */
function maxImpl$1(x, reduceShape, outShape, backend) {
    var reduceSize = tf.util.sizeFromShape(reduceShape);
    var totalSize = tf.util.sizeFromShape(x.shape);
    var input2D = reshape({
        inputs: { x: x },
        attrs: { shape: [totalSize / reduceSize, reduceSize] },
        backend: backend
    });
    var maxima = reduce(input2D, x.dtype, 'max', backend);
    var result = reshape({
        inputs: { x: maxima },
        attrs: { shape: outShape },
        backend: backend
    });
    backend.disposeIntermediateTensorInfo(input2D);
    backend.disposeIntermediateTensorInfo(maxima);
    return result;
}
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Kernel function for `Max`: reduces `inputs.x` along
 * `attrs.reductionIndices`.
 *
 * If `getAxesPermutation` returns a permutation, the input is transposed
 * first (with `transposeImplCPU` when `backend.shouldExecuteOnCPU([x])` is
 * true, otherwise with `transposeImpl$1`) and the reduction axes are
 * replaced by the inner-most axes. The reduction itself uses `maxImplCPU`
 * on the values stored in `backend.texData` when running on CPU, and
 * `maxImpl$1` otherwise.
 *
 * @param args Object with `inputs.x`, `backend`, and
 *     `attrs.{reductionIndices, keepDims}`.
 * @returns Tensor info holding the reduced values; when `keepDims` is
 *     truthy its shape comes from `expandShapeToKeepDim`.
 */
function max(args) {
    var backend = args.backend;
    var x = args.inputs.x;
    var reductionIndices = args.attrs.reductionIndices;
    var keepDims = args.attrs.keepDims;
    var rank = x.shape.length;
    var requestedAxes = tf.util.parseAxisParam(reductionIndices, x.shape);
    var perm = tf.backend_util.getAxesPermutation(requestedAxes, rank);
    var needsTranspose = perm != null;
    var runOnCPU = backend.shouldExecuteOnCPU([x]);
    var reduceAxes = requestedAxes;
    var source = x;
    if (needsTranspose) {
        if (!runOnCPU) {
            source = transposeImpl$1(x, perm, backend);
        }
        else {
            var srcValues = backend.texData.get(x.dataId).values;
            var permutedShape = [];
            for (var d = 0; d < rank; d++) {
                permutedShape.push(x.shape[perm[d]]);
            }
            var transposedValues = transposeImplCPU(srcValues, x.shape, x.dtype, perm, permutedShape);
            source = backend.makeTensorInfo(permutedShape, x.dtype);
            backend.texData.get(source.dataId).values = transposedValues;
        }
        reduceAxes = tf.backend_util.getInnerMostAxes(requestedAxes.length, rank);
    }
    tf.backend_util.assertAxesAreInnerMostDims('max', reduceAxes, rank);
    var shapes = tf.backend_util.computeOutAndReduceShapes(source.shape, reduceAxes);
    var reducedShape = shapes[0];
    var reduceShape = shapes[1];
    // Rather than reshape at the end, pick the target shape up front.
    var outShape = keepDims ?
        tf.backend_util.expandShapeToKeepDim(reducedShape, requestedAxes) :
        reducedShape;
    var out;
    if (runOnCPU) {
        var cpuValues = maxImplCPU(backend.texData.get(source.dataId).values, tf.util.sizeFromShape(reduceShape), outShape, x.dtype);
        out = backend.makeTensorInfo(outShape, x.dtype);
        backend.texData.get(out.dataId).values = cpuValues;
    }
    else {
        out = maxImpl$1(source, reduceShape, outShape, backend);
    }
    if (needsTranspose) {
        // `source` is a temporary created above, not the caller's tensor.
        backend.disposeIntermediateTensorInfo(source);
    }
    return out;
}
|
|
// Kernel config pairing the tf.Max kernel name with the `max` kernel
// function for the 'webgl' backend.
var maxConfig = { kernelName: tf.Max, backendName: 'webgl', kernelFunc: max };
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// Shader snippet for element-wise maximum: CHECK_NAN_SNIPPET$1 (defined
// elsewhere in this file) followed by `return max(a, b);`.
var MAXIMUM = CHECK_NAN_SNIPPET$1 +
    "\n return max(a, b);\n";
// Packed (vec4) variant: computes `result` and an `isNaN` mask, then splices
// in CHECK_NAN_SNIPPET$2 (defined elsewhere) before returning `result`.
var MAXIMUM_PACKED = "\n vec4 result = vec4(max(a, b));" +
    "\n vec4 isNaN = min(vec4(isnan(a)) + vec4(isnan(b)), vec4(1.0));\n " +
    CHECK_NAN_SNIPPET$2 +
    "\n return result;\n";
// Kernel function built by `binaryKernelFunc` from the two snippets above
// and the CPU implementation `maximumImplCPU`.
var maximum = binaryKernelFunc({
    opSnippet: MAXIMUM,
    packedOpSnippet: MAXIMUM_PACKED,
    cpuKernelImpl: maximumImplCPU
});
// Kernel config pairing tf.Maximum with `maximum` for the 'webgl' backend.
var maximumConfig = { kernelName: tf.Maximum, backendName: 'webgl', kernelFunc: maximum };
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Kernel function for 2D max pooling.
 *
 * Dilations are fixed at 1. When the computed pool info describes a 1x1
 * filter whose input and output shapes are equal, the input is passed
 * through `identity`; otherwise a `Pool2DProgram` in 'max' mode is run.
 *
 * @param args Object with `inputs.x`, `backend`, and
 *     `attrs.{filterSize, strides, pad, dimRoundingMode}`.
 * @returns The pooled tensor info with the dtype of `x`.
 */
function maxPool(args) {
    var backend = args.backend;
    var x = args.inputs.x;
    assertNotComplex(x, 'maxPool');
    var poolAttrs = args.attrs;
    var strides = poolAttrs.strides;
    var dilations = 1;
    tf.util.assert(tf.backend_util.eitherStridesOrDilationsAreOne(strides, dilations), function () {
        return 'Error in maxPool: Either strides or dilations must be 1. ' +
            ("Got strides " + strides + " and dilations '" + dilations + "'");
    });
    var convInfo = tf.backend_util.computePool2DInfo(x.shape, poolAttrs.filterSize, strides, dilations, poolAttrs.pad, poolAttrs.dimRoundingMode);
    var isPassThrough = convInfo.filterWidth === 1 &&
        convInfo.filterHeight === 1 &&
        tf.util.arraysEqual(convInfo.inShape, convInfo.outShape);
    if (isPassThrough) {
        return identity({ inputs: { x: x }, backend: backend });
    }
    return backend.runWebGLProgram(new Pool2DProgram(convInfo, 'max', false), [x], x.dtype);
}
|
|
// Kernel config pairing the tf.MaxPool kernel name with `maxPool` for the
// 'webgl' backend.
var maxPoolConfig = { kernelName: tf.MaxPool, backendName: 'webgl', kernelFunc: maxPool };
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Kernel function for 3D max pooling.
 *
 * Computes the pooling info with dilations fixed at [1, 1, 1] and runs a
 * `Pool3DProgram` in 'max' mode over `inputs.x`.
 *
 * @param args Object with `inputs.x`, `backend`, and
 *     `attrs.{filterSize, strides, pad, dataFormat, dimRoundingMode}`.
 * @returns The pooled tensor info with the dtype of `x`.
 */
function maxPool3d(args) {
    var x = args.inputs.x;
    var poolAttrs = args.attrs;
    var unitDilations = [1, 1, 1];
    var convInfo = tf.backend_util.computePool3DInfo(x.shape, poolAttrs.filterSize, poolAttrs.strides, unitDilations, poolAttrs.pad, poolAttrs.dimRoundingMode, poolAttrs.dataFormat);
    var program = new Pool3DProgram(convInfo, 'max', false);
    return args.backend.runWebGLProgram(program, [x], x.dtype);
}
|
|
// Kernel config pairing the tf.MaxPool3D kernel name with `maxPool3d` for
// the 'webgl' backend.
var maxPool3DConfig = { kernelName: tf.MaxPool3D, backendName: 'webgl', kernelFunc: maxPool3d };
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2017 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
var MaxPool2DBackpropProgram = /** @class */ (function () {
    /**
     * Shader program description for the gradient of 2D max pooling.
     *
     * The generated shader takes two inputs, `dy` and `maxPos`, and writes an
     * output of shape `convInfo.inShape`. For each output coordinate it walks
     * the effective filter window, keeps only window taps that land exactly
     * on a dy element (integral, in-range dy row/column), and accumulates
     * `dy` where the stored max position matches the current tap.
     *
     * The column loop steps by `convInfo.dilationWidth`, mirroring the row
     * loop (`dilationHeight`) and MaxPool3DBackpropProgram. It previously
     * used `wC++`, ignoring the width dilation. For a dilation of 1, which
     * is what `maxPoolGrad` in this file passes, the emitted loop is
     * equivalent to the old one.
     *
     * @param convInfo Pooling info; reads `inShape`, `strideHeight`,
     *     `strideWidth`, `dilationHeight`, `dilationWidth`,
     *     `effectiveFilterHeight`, `effectiveFilterWidth`, `padInfo.top`,
     *     `padInfo.left`, `outHeight` and `outWidth`.
     */
    function MaxPool2DBackpropProgram(convInfo) {
        this.variableNames = ['dy', 'maxPos'];
        this.outputShape = convInfo.inShape;
        var strideHeight = convInfo.strideHeight;
        var strideWidth = convInfo.strideWidth;
        var dilationHeight = convInfo.dilationHeight;
        var dilationWidth = convInfo.dilationWidth;
        var effectiveFilterHeight = convInfo.effectiveFilterHeight;
        var effectiveFilterWidth = convInfo.effectiveFilterWidth;
        var padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top;
        var padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left;
        // Largest flattened index inside the effective filter window; the
        // shader subtracts the stored position from it before comparing.
        var lastIndex = effectiveFilterHeight * effectiveFilterWidth - 1;
        this.userCode = [
            '',
            ' const ivec2 pads = ivec2(' + padTop + ', ' + padLeft + ');',
            '',
            ' void main() {',
            ' ivec4 coords = getOutputCoords();',
            ' int b = coords[0];',
            ' int d = coords[3];',
            '',
            ' ivec2 dyRCCorner = coords.yz - pads;',
            ' int dyRCorner = dyRCCorner.x;',
            ' int dyCCorner = dyRCCorner.y;',
            '',
            ' // Convolve dy(?, ?, d) with pos mask(:, :, d) to get dx(xR, xC, d).',
            ' // ? = to be determined. : = across all values in that axis.',
            ' float dotProd = 0.0;',
            ' for (int wR = 0; wR < ' + effectiveFilterHeight + ';',
            ' wR += ' + dilationHeight + ') {',
            ' float dyR = float(dyRCorner + wR) / ' + strideHeight + '.0;',
            '',
            ' if (dyR < 0.0 || dyR >= ' + convInfo.outHeight + '.0 || fract(dyR) > 0.0) {',
            ' continue;',
            ' }',
            ' int idyR = int(dyR);',
            '',
            ' for (int wC = 0; wC < ' + effectiveFilterWidth + '; wC += ' + dilationWidth + ') {',
            ' float dyC = float(dyCCorner + wC) / ' + strideWidth + '.0;',
            '',
            ' if (dyC < 0.0 || dyC >= ' + convInfo.outWidth + '.0 ||',
            ' fract(dyC) > 0.0) {',
            ' continue;',
            ' }',
            ' int idyC = int(dyC);',
            '',
            ' float dyValue = getDy(b, idyR, idyC, d);',
            ' int maxPosValue = ' + lastIndex + ' - int(getMaxPos(b, idyR, idyC, d));',
            '',
            ' // Get the current value, check it against the value from the',
            ' // position matrix.',
            ' int curPosValue = wR * ' + effectiveFilterWidth + ' + wC;',
            ' float mask = float(maxPosValue == curPosValue ? 1.0 : 0.0);',
            '',
            ' dotProd += dyValue * mask;',
            ' }',
            ' }',
            ' setOutput(dotProd);',
            ' }',
            ' '
        ].join('\n');
    }
    return MaxPool2DBackpropProgram;
}());
|
|
var MaxPool3DBackpropProgram = /** @class */ (function () {
    /**
     * Shader program description for the gradient of 3D max pooling.
     *
     * The generated shader takes two inputs, `dy` and `maxPos`, and writes an
     * output of shape `convInfo.inShape`. It walks the effective filter
     * window in depth, row and column (each stepping by its dilation), keeps
     * only taps that land exactly on a dy element, and accumulates `dy` where
     * the stored max position matches the current tap.
     *
     * @param convInfo Pooling info; reads `inShape`, the three strides,
     *     dilations and effective filter sizes, `padInfo.{front, top, left}`
     *     and `outDepth`/`outHeight`/`outWidth`.
     */
    function MaxPool3DBackpropProgram(convInfo) {
        this.variableNames = ['dy', 'maxPos'];
        this.outputShape = convInfo.inShape;
        var strideD = convInfo.strideDepth;
        var strideH = convInfo.strideHeight;
        var strideW = convInfo.strideWidth;
        var dilationD = convInfo.dilationDepth;
        var dilationH = convInfo.dilationHeight;
        var dilationW = convInfo.dilationWidth;
        var filterD = convInfo.effectiveFilterDepth;
        var filterH = convInfo.effectiveFilterHeight;
        var filterW = convInfo.effectiveFilterWidth;
        var padFront = filterD - 1 - convInfo.padInfo.front;
        var padTop = filterH - 1 - convInfo.padInfo.top;
        var padLeft = filterW - 1 - convInfo.padInfo.left;
        // Largest flattened index inside the effective filter volume.
        var lastIndex = filterD * filterH * filterW - 1;
        this.userCode = [
            '',
            ' const ivec3 pads = ivec3(' + padFront + ', ' + padTop + ', ' + padLeft + ');',
            '',
            ' void main() {',
            ' ivec5 coords = getOutputCoords();',
            ' int batch = coords.x;',
            ' int ch = coords.u;',
            '',
            ' ivec3 dyCorner = ivec3(coords.y, coords.z, coords.w) - pads;',
            ' int dyDCorner = dyCorner.x;',
            ' int dyRCorner = dyCorner.y;',
            ' int dyCCorner = dyCorner.z;',
            '',
            ' // Convolve dy(?, ?, ?, ch) with pos mask(:, :, :, d) to get',
            ' // dx(xD, xR, xC, ch).',
            ' // ? = to be determined. : = across all values in that axis.',
            ' float dotProd = 0.0;',
            '',
            ' for (int wD = 0; wD < ' + filterD + ';',
            ' wD += ' + dilationD + ') {',
            ' float dyD = float(dyDCorner + wD) / ' + strideD + '.0;',
            '',
            ' if (dyD < 0.0 || dyD >= ' + convInfo.outDepth + '.0 || fract(dyD) > 0.0) {',
            ' continue;',
            ' }',
            ' int idyD = int(dyD);',
            '',
            ' for (int wR = 0; wR < ' + filterH + ';',
            ' wR += ' + dilationH + ') {',
            ' float dyR = float(dyRCorner + wR) / ' + strideH + '.0;',
            '',
            ' if (dyR < 0.0 || dyR >= ' + convInfo.outHeight + '.0 ||',
            ' fract(dyR) > 0.0) {',
            ' continue;',
            ' }',
            ' int idyR = int(dyR);',
            '',
            ' for (int wC = 0; wC < ' + filterW + ';',
            ' wC += ' + dilationW + ') {',
            ' float dyC = float(dyCCorner + wC) / ' + strideW + '.0;',
            '',
            ' if (dyC < 0.0 || dyC >= ' + convInfo.outWidth + '.0 ||',
            ' fract(dyC) > 0.0) {',
            ' continue;',
            ' }',
            ' int idyC = int(dyC);',
            '',
            ' float dyValue = getDy(batch, idyD, idyR, idyC, ch);',
            ' int maxPosValue = ' + lastIndex + ' -',
            ' int(getMaxPos(batch, idyD, idyR, idyC, ch));',
            '',
            ' // Get the current value, check it against the value from the',
            ' // position matrix.',
            ' int curPosValue =',
            ' wD * ' + filterH + ' * ' + filterW + ' +',
            ' wR * ' + filterW + ' + wC;',
            ' float mask = float(maxPosValue == curPosValue ? 1.0 : 0.0);',
            '',
            ' dotProd += dyValue * mask;',
            ' }',
            ' }',
            ' }',
            ' setOutput(dotProd);',
            ' }',
            ' '
        ].join('\n');
    }
    return MaxPool3DBackpropProgram;
}());
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Kernel function for the gradient of 3D max pooling.
 *
 * Runs a `Pool3DProgram` in 'max' mode with positions enabled over
 * `inputs.input`, then feeds `inputs.dy` and those positions to a
 * `MaxPool3DBackpropProgram`. The positions tensor is disposed before
 * returning.
 *
 * @param args Object with `inputs.{dy, input}`, `backend`, and
 *     `attrs.{filterSize, strides, pad, dimRoundingMode}`.
 * @returns The result of the backprop program, with the dtype of `input`.
 */
function maxPool3DGrad(args) {
    var backend = args.backend;
    var dy = args.inputs.dy;
    var x = args.inputs.input;
    var poolAttrs = args.attrs;
    var unitDilations = [1, 1, 1];
    var convInfo = tf.backend_util.computePool3DInfo(x.shape, poolAttrs.filterSize, poolAttrs.strides, unitDilations, poolAttrs.pad, poolAttrs.dimRoundingMode);
    var positions = backend.runWebGLProgram(new Pool3DProgram(convInfo, 'max', true /* get positions */), [x], x.dtype);
    var dx = backend.runWebGLProgram(new MaxPool3DBackpropProgram(convInfo), [dy, positions], x.dtype);
    backend.disposeIntermediateTensorInfo(positions);
    return dx;
}
|
|
// Kernel config pairing the tf.MaxPool3DGrad kernel name with
// `maxPool3DGrad` for the 'webgl' backend.
var maxPoolGrad3DConfig = { kernelName: tf.MaxPool3DGrad, backendName: 'webgl', kernelFunc: maxPool3DGrad };
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Kernel function for the gradient of 2D max pooling.
 *
 * Runs a `Pool2DProgram` in 'max' mode with positions enabled over
 * `inputs.input`, then feeds `inputs.dy` and those positions to a
 * `MaxPool2DBackpropProgram`. Dilations are fixed at 1, and the positions
 * tensor is disposed before returning.
 *
 * @param args Object with `inputs.{dy, input, output}`, `backend`, and
 *     `attrs.{filterSize, strides, pad, dimRoundingMode}`.
 * @returns The result of the backprop program, with the dtype of `input`.
 */
function maxPoolGrad(args) {
    var backend = args.backend;
    var dy = args.inputs.dy;
    var x = args.inputs.input;
    assertNotComplex([x, args.inputs.output], 'maxPoolGrad');
    var poolAttrs = args.attrs;
    var convInfo = tf.backend_util.computePool2DInfo(x.shape, poolAttrs.filterSize, poolAttrs.strides, 1 /* dilations */, poolAttrs.pad, poolAttrs.dimRoundingMode);
    var positions = backend.runWebGLProgram(new Pool2DProgram(convInfo, 'max', true /* get positions */), [x], x.dtype);
    var dx = backend.runWebGLProgram(new MaxPool2DBackpropProgram(convInfo), [dy, positions], x.dtype);
    backend.disposeIntermediateTensorInfo(positions);
    return dx;
}
|
|
// Kernel config pairing the tf.MaxPoolGrad kernel name with `maxPoolGrad`
// for the 'webgl' backend.
var maxPoolGradConfig = { kernelName: tf.MaxPoolGrad, backendName: 'webgl', kernelFunc: maxPoolGrad };
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Runs max pooling twice over `x`: once for the pooled values and once for
 * the positions of the maxima.
 *
 * Both passes use `Pool2DProgram` in 'max' mode and request a 'float32'
 * output. The second pass is constructed with
 * `(convInfo, 'max', true, true, includeBatchInIndex)`.
 *
 * @param x Input tensor info.
 * @param includeBatchInIndex Forwarded as the last argument of the
 *     positions program.
 * @param convInfo Pooling info shared by both programs.
 * @param backend Backend whose `runWebGLProgram` executes the programs.
 * @returns `[poolOutput, indexOutput]`.
 */
function maxPoolWithArgmaxImpl(x, includeBatchInIndex, convInfo, backend) {
    var valuesProgram = new Pool2DProgram(convInfo, 'max', false);
    var pooled = backend.runWebGLProgram(valuesProgram, [x], 'float32');
    var indexProgram = new Pool2DProgram(convInfo, 'max', true, true, includeBatchInIndex);
    var indices = backend.runWebGLProgram(indexProgram, [x], 'float32');
    return [pooled, indices];
}
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// Kernel config for tf.MaxPoolWithArgmax on the 'webgl' backend.
var maxPoolWithArgmaxConfig = {
    kernelName: tf.MaxPoolWithArgmax,
    backendName: 'webgl',
    /**
     * Validates that `inputs.x` is rank 4, computes 2D pooling info with
     * dilations fixed at [1, 1], and delegates to `maxPoolWithArgmaxImpl`.
     *
     * @param _a Object with `inputs.x`, `backend`, and
     *     `attrs.{filterSize, strides, pad, includeBatchInIndex}`.
     * @returns A two-element array: pooled values and max indexes.
     */
    kernelFunc: function (_a) {
        var x = _a.inputs.x;
        var poolAttrs = _a.attrs;
        var strides = poolAttrs.strides;
        tf.util.assert(x.shape.length === 4, function () {
            return "Error in maxPool: input must be rank 4 but got rank " + x.shape.length + ".";
        });
        var dilations = [1, 1];
        tf.util.assert(tf.backend_util.eitherStridesOrDilationsAreOne(strides, dilations), function () {
            return 'Error in maxPool: Either strides or dilations must be 1. ' +
                ("Got strides " + strides + " and dilations '" + dilations + "'");
        });
        var convInfo = tf.backend_util.computePool2DInfo(x.shape, poolAttrs.filterSize, strides, dilations, poolAttrs.pad);
        var outputs = maxPoolWithArgmaxImpl(x, poolAttrs.includeBatchInIndex, convInfo, _a.backend);
        return [outputs[0], outputs[1]];
    }
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Mean-reduces `x` over the dimensions described by `reduceShape`.
 *
 * `x` is reshaped to a [batchSize, inSize] matrix, where inSize is the
 * element count of `reduceShape`. Each row is reduced with the 'mean'
 * reduction, requesting a 'float32' result, and the reduction is reshaped to
 * `outShape`.
 *
 * @param x Input tensor info; its `shape` is read.
 * @param reduceShape Shape of the dimensions being reduced.
 * @param outShape Shape applied to the returned tensor info.
 * @param backend Backend passed to `reshape`/`reduce`; also used to dispose
 *     the two intermediates created here.
 * @returns The reduced result reshaped to `outShape`.
 */
function meanImpl(x, reduceShape, outShape, backend) {
    var reduceSize = tf.util.sizeFromShape(reduceShape);
    var rowCount = tf.util.sizeFromShape(x.shape) / reduceSize;
    var input2D = reshape({
        inputs: { x: x },
        attrs: { shape: [rowCount, reduceSize] },
        backend: backend
    });
    var meanPerRow = reduce(input2D, 'float32', 'mean', backend);
    var result = reshape({
        inputs: { x: meanPerRow },
        attrs: { shape: outShape },
        backend: backend
    });
    // Only the final reshape is handed back; release the temporaries.
    [input2D, meanPerRow].forEach(function (intermediate) {
        backend.disposeIntermediateTensorInfo(intermediate);
    });
    return result;
}
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// Kernel config for tf.Mean on the 'webgl' backend.
var meanConfig = {
    kernelName: tf.Mean,
    backendName: 'webgl',
    /**
     * Reduces `inputs.x` to its mean along `attrs.axis`.
     *
     * If `getAxesPermutation` returns a permutation, the input is transposed
     * first (with `transposeImplCPU` when `backend.shouldExecuteOnCPU([x])`
     * is true, otherwise with `transposeImpl$1`) and the reduction axes are
     * replaced by the inner-most axes. The reduction itself always goes
     * through `meanImpl`; the CPU check only selects the transpose path.
     *
     * @param _a Object with `inputs.x`, `backend`, and
     *     `attrs.{axis, keepDims}`.
     * @returns Tensor info holding the means; when `keepDims` is truthy its
     *     shape comes from `expandShapeToKeepDim`.
     */
    kernelFunc: function (_a) {
        var backend = _a.backend;
        var x = _a.inputs.x;
        var axis = _a.attrs.axis;
        var keepDims = _a.attrs.keepDims;
        var rank = x.shape.length;
        var requestedAxes = tf.util.parseAxisParam(axis, x.shape);
        var perm = tf.backend_util.getAxesPermutation(requestedAxes, rank);
        var needsTranspose = perm != null;
        var runOnCPU = backend.shouldExecuteOnCPU([x]);
        var reduceAxes = requestedAxes;
        var source = x;
        if (needsTranspose) {
            if (runOnCPU) {
                var srcValues = backend.texData.get(x.dataId).values;
                var permutedShape = [];
                for (var d = 0; d < rank; d++) {
                    permutedShape.push(x.shape[perm[d]]);
                }
                var transposedValues = transposeImplCPU(srcValues, x.shape, x.dtype, perm, permutedShape);
                source = backend.makeTensorInfo(permutedShape, x.dtype);
                backend.texData.get(source.dataId).values = transposedValues;
            }
            else {
                source = transposeImpl$1(x, perm, backend);
            }
            reduceAxes = tf.backend_util.getInnerMostAxes(requestedAxes.length, rank);
        }
        // Label the check with this op's name; it previously said 'sum'.
        tf.backend_util.assertAxesAreInnerMostDims('mean', reduceAxes, rank);
        var shapes = tf.backend_util.computeOutAndReduceShapes(source.shape, reduceAxes);
        var reducedShape = shapes[0];
        var reduceShape = shapes[1];
        // Rather than reshape at the end, pick the target shape up front.
        var outShape = keepDims ?
            tf.backend_util.expandShapeToKeepDim(reducedShape, requestedAxes) :
            reducedShape;
        var out = meanImpl(source, reduceShape, outShape, backend);
        if (needsTranspose) {
            // `source` is a temporary created above, not the caller's tensor.
            backend.disposeIntermediateTensorInfo(source);
        }
        return out;
    }
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Kernel function for `Min`: reduces `inputs.x` along `attrs.axis`.
 *
 * If `getAxesPermutation` returns a permutation, the input is transposed
 * with `transpose` and the reduction axes are replaced by the inner-most
 * axes. The (possibly transposed) input is reshaped to [-1, inSize],
 * reduced with the 'min' reduction, and reshaped to the output shape.
 * All intermediates created here are disposed before returning.
 *
 * @param args Object with `inputs.x`, `backend`, and
 *     `attrs.{axis, keepDims}`.
 * @returns Tensor info holding the minima; when `keepDims` is truthy its
 *     shape comes from `expandShapeToKeepDim`.
 */
function min(args) {
    var backend = args.backend;
    var x = args.inputs.x;
    var axis = args.attrs.axis;
    var keepDims = args.attrs.keepDims;
    var rank = x.shape.length;
    var requestedAxes = tf.util.parseAxisParam(axis, x.shape);
    var perm = tf.backend_util.getAxesPermutation(requestedAxes, rank);
    var wasTransposed = perm != null;
    var reduceAxes = requestedAxes;
    var source = x;
    if (wasTransposed) {
        source = transpose({ inputs: { x: x }, backend: backend, attrs: { perm: perm } });
        reduceAxes = tf.backend_util.getInnerMostAxes(requestedAxes.length, x.shape.length);
    }
    tf.backend_util.assertAxesAreInnerMostDims('min', reduceAxes, rank);
    var shapes = tf.backend_util.computeOutAndReduceShapes(source.shape, reduceAxes);
    var reducedShape = shapes[0];
    var reduceSize = tf.util.sizeFromShape(shapes[1]);
    var rows = reshape({ inputs: { x: source }, backend: backend, attrs: { shape: [-1, reduceSize] } });
    var minPerRow = reduce(rows, rows.dtype, 'min', backend);
    var finalShape = keepDims ?
        tf.backend_util.expandShapeToKeepDim(reducedShape, requestedAxes) :
        reducedShape;
    var result = reshape({ inputs: { x: minPerRow }, backend: backend, attrs: { shape: finalShape } });
    backend.disposeIntermediateTensorInfo(rows);
    backend.disposeIntermediateTensorInfo(minPerRow);
    if (wasTransposed) {
        backend.disposeIntermediateTensorInfo(source);
    }
    return result;
}
|
|
// Kernel config pairing the tf.Min kernel name with the `min` kernel
// function for the 'webgl' backend.
var minConfig = { kernelName: tf.Min, backendName: 'webgl', kernelFunc: min };
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// Shader snippet for element-wise minimum: CHECK_NAN_SNIPPET$1 (defined
// elsewhere in this file) followed by `return min(a, b);`.
var MINIMUM = CHECK_NAN_SNIPPET$1 +
    "\n return min(a, b);\n";
// Packed (vec4) variant: computes `result` and an `isNaN` mask, then splices
// in CHECK_NAN_SNIPPET$2 (defined elsewhere) before returning `result`.
var MINIMUM_PACKED = "\n vec4 result = vec4(min(a, b));" +
    "\n vec4 isNaN = min(vec4(isnan(a)) + vec4(isnan(b)), vec4(1.0));\n " +
    CHECK_NAN_SNIPPET$2 +
    "\n return result;\n";
// Kernel function built by `binaryKernelFunc` from the two snippets above
// and the CPU implementation `minimumImplCPU`.
var minimum = binaryKernelFunc({
    opSnippet: MINIMUM,
    packedOpSnippet: MINIMUM_PACKED,
    cpuKernelImpl: minimumImplCPU
});
// Kernel config pairing tf.Minimum with `minimum` for the 'webgl' backend.
var minimumConfig = { kernelName: tf.Minimum, backendName: 'webgl', kernelFunc: minimum };
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
var MirrorPadProgram = /** @class */ (function () {
    /**
     * Shader program description for mirror padding on unpacked textures.
     *
     * The output shape is `before + xShape[i] + after` per axis. In the
     * generated shader, output coordinates below `start` (the before-pad
     * amounts) or at/after `end` (`start + xShape`) are mirrored back into
     * the [start, end) range and then shifted by `start` to index `x`.
     *
     * @param xShape Shape of the input `x`.
     * @param paddings One `[before, after]` pair per axis.
     * @param mode 'reflect' uses a mirror offset of 0; any other value uses 1.
     */
    function MirrorPadProgram(xShape, paddings, mode) {
        this.variableNames = ['x'];
        this.outputShape = paddings.map(function (pad, axis) {
            return pad[0] /* beforePad */ + xShape[axis] + pad[1] /* afterPad */;
        });
        var rank = xShape.length;
        var coordsType = getCoordsDataType(rank);
        var startList = paddings.map(function (pad) { return pad[0]; }).join(',');
        var endList = paddings.map(function (pad, axis) { return pad[0] + xShape[axis]; }).join(',');
        var coordArgs = ['coords[0]', 'coords[1]', 'coords[2]', 'coords[3]'].slice(0, rank);
        var edgeOffset = mode === 'reflect' ? 0 : 1;
        var shaderLines;
        if (rank === 1) {
            // Scalar coordinates: no per-axis loop is emitted.
            shaderLines = [
                '',
                ' int start = ' + startList + ';',
                ' int end = ' + endList + ';',
                '',
                ' void main() {',
                ' int outC = getOutputCoords();',
                ' if (outC < start) {',
                ' outC = start * 2 - outC - ' + edgeOffset + ';',
                ' } else if(outC >= end) {',
                ' outC = (end - 1) * 2 - outC + ' + edgeOffset + ';',
                ' }',
                ' setOutput(getX(outC - start));',
                ' }',
                ' '
            ];
        }
        else {
            shaderLines = [
                '',
                ' ' + coordsType + ' start = ' + coordsType + '(' + startList + ');',
                ' ' + coordsType + ' end = ' + coordsType + '(' + endList + ');',
                '',
                ' void main() {',
                ' ' + coordsType + ' outC = getOutputCoords();',
                ' for (int i = 0; i < ' + rank + '; i++) {',
                ' if (outC[i] < start[i]) {',
                ' outC[i] = start[i] * 2 - outC[i] - ' + edgeOffset + ';',
                ' } else if(outC[i] >= end[i]) {',
                ' outC[i] = (end[i] - 1) * 2 - outC[i] + ' + edgeOffset + ';',
                ' }',
                ' }',
                ' ' + coordsType + ' coords = outC - start;',
                ' setOutput(getX(' + coordArgs.join(',') + '));',
                ' }',
                ' '
            ];
        }
        this.userCode = shaderLines.join('\n');
    }
    return MirrorPadProgram;
}());
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
|
|
* Example shader code for
|
|
* `mirrorPad(tf.tensor1d([1, 2, 3], 'int32'), [[2, 2]], 'reflect')`
|
|
* ```
|
|
* const int start = int(2);
|
|
* const int end = int(5);
|
|
*
|
|
* void main() {
|
|
* int outputLoc = getOutputCoords();
|
|
* vec4 result = vec4(0.);
|
|
*
|
|
* int rc = outputLoc;
|
|
*
|
|
* int source = rc;
|
|
* if (source < start) {
|
|
* source = start * 2 - source - 0;
|
|
* } else if (source >= end) {
|
|
* source = (end - 1) * 2 - source + 0;
|
|
* }
|
|
* source -= start;
|
|
*
|
|
* result[0] = getChannel(getX(source), source);
|
|
* rc += 1;
|
|
* if(rc < 6) {
|
|
* int source = rc;
|
|
* if (source < start) {
|
|
* source = start * 2 - source - 0;
|
|
* } else if (source >= end) {
|
|
* source = (end - 1) * 2 - source + 0;
|
|
* }
|
|
* source -= start;
|
|
*
|
|
* result[1] = getChannel(getX(source), source);
|
|
* }
|
|
*
|
|
* setOutput(result);
|
|
* }
|
|
* ```
|
|
*/
|
|
var MirrorPadPackedProgram = /** @class */ (function () {
    /**
     * Shader program description for mirror padding on packed textures.
     *
     * The output shape is `before + xShape[i] + after` per axis. The
     * generated shader fills up to four `result` channels per invocation:
     * result[0] at the output location, result[1] one step along the last
     * axis (if in bounds) and, for rank > 1, result[2]/result[3] one step
     * along the second-to-last axis. Each channel first maps its coordinate
     * back into the source range via the pad-setup snippet.
     *
     * @param xShape Shape of the input `x`.
     * @param paddings One `[before, after]` pair per axis.
     * @param mode 'reflect' uses a mirror offset of 0; any other value uses 1.
     */
    function MirrorPadPackedProgram(xShape, paddings, mode) {
        this.variableNames = ['x'];
        this.packedInputs = true;
        this.packedOutput = true;
        this.outputShape = paddings.map(function (pad, axis) {
            return pad[0] /* beforePad */ + xShape[axis] + pad[1] /* afterPad */;
        });
        var outShape = this.outputShape;
        var rank = xShape.length;
        var coordsType = getCoordsDataType(rank);
        var startList = paddings.map(function (pad) { return pad[0]; }).join(',');
        var endList = paddings.map(function (pad, axis) { return pad[0] + xShape[axis]; }).join(',');
        var rcChannels = getChannels('rc', rank);
        var sourceChannels = getChannels('source', rank);
        var lastCoord = rcChannels[rank - 1];
        var lastInBounds = lastCoord + " < " + outShape[rank - 1];
        var texelChannel = rank === 1 ? 'source' : "vec2(" + sourceChannels.slice(-2).join() + ")";
        var edgeOffset = mode === 'reflect' ? 0 : 1;
        // Read of one channel from x at the mirrored `source` coordinate.
        var fetch = 'getChannel(getX(' + sourceChannels.join() + '), ' + texelChannel + ');';
        // Snippet that turns `rc` into the mirrored source coordinate.
        var padSetup;
        if (rank === 1) {
            padSetup = [
                '',
                ' ' + coordsType + ' source = rc;',
                ' if (source < start) {',
                ' source = start * 2 - source - ' + edgeOffset + ';',
                ' } else if (source >= end) {',
                ' source = (end - 1) * 2 - source + ' + edgeOffset + ';',
                ' }',
                ' source -= start;',
                ' '
            ].join('\n');
        }
        else {
            padSetup = [
                '',
                ' ' + coordsType + ' source = rc;',
                ' ' + coordsType + ' lt = ' + coordsType + '(lessThan(source, start));',
                ' ' + coordsType + ' gte = ' + coordsType + '(greaterThanEqual(source, end));',
                ' ' + coordsType + ' orig = 1 - (lt + gte);',
                ' source = orig * source +',
                ' lt * (start * 2 - source - ' + edgeOffset + ') +',
                ' gte * ((end - 1) * 2 - source + ' + edgeOffset + ');',
                ' source -= start;',
                ' '
            ].join('\n');
        }
        // Emits result[first] at the current rc and result[first + 1] one
        // step further along the last axis, guarded by the bounds check.
        function texelPair(first) {
            return padSetup +
                '\n result[' + first + '] = ' + fetch +
                '\n ' + lastCoord + ' += 1;' +
                '\n if(' + lastInBounds + ') {\n ' +
                padSetup +
                '\n result[' + (first + 1) + '] = ' + fetch +
                '\n }\n ';
        }
        var mainLoop = '\n ' + coordsType + ' rc = outputLoc;\n ' + texelPair(0);
        if (rank !== 1) {
            var rowCoord = rcChannels[rank - 2];
            mainLoop += 'rc = outputLoc;\n ' + rowCoord + ' += 1;' +
                '\n if(' + rowCoord + ' < ' + outShape[rank - 2] + ') {\n ' +
                texelPair(2) +
                '}\n ';
        }
        this.userCode = [
            '',
            ' const ' + coordsType + ' start = ' + coordsType + '(' + startList + ');',
            ' const ' + coordsType + ' end = ' + coordsType + '(' + endList + ');',
            '',
            ' void main() {',
            ' ' + coordsType + ' outputLoc = getOutputCoords();',
            ' vec4 result = vec4(0.);',
            ' ' + mainLoop,
            ' setOutput(result);',
            ' }',
            ' '
        ].join('\n');
    }
    return MirrorPadPackedProgram;
}());
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * WebGL kernel for MirrorPad.
 *
 * Runs the packed shader when the `WEBGL_PACK_ARRAY_OPERATIONS` flag is set
 * and the unpacked one otherwise. The output keeps the dtype of `x`.
 *
 * @param args Kernel arguments: `inputs.x`, `backend`, and
 *     `attrs.paddings` / `attrs.mode`.
 * @returns The tensor info produced by `backend.runWebGLProgram`.
 */
var mirrorPadKernelFunc = function (args) {
    var backend = args.backend;
    var x = args.inputs.x;
    var paddings = args.attrs.paddings;
    var mode = args.attrs.mode;
    var program;
    if (tf.env().getBool('WEBGL_PACK_ARRAY_OPERATIONS')) {
        program = new MirrorPadPackedProgram(x.shape, paddings, mode);
    }
    else {
        program = new MirrorPadProgram(x.shape, paddings, mode);
    }
    return backend.runWebGLProgram(program, [x], x.dtype);
};
/** Kernel registration entry for MirrorPad on the 'webgl' backend. */
var mirrorPadConfig = {
    kernelName: tf.MirrorPad,
    backendName: 'webgl',
    kernelFunc: mirrorPadKernelFunc
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// Unpacked shader body: a zero divisor yields NAN, otherwise GLSL mod().
var MOD = "if (b == 0.0) return NAN;\n" +
    " return mod(a, b);";
// Packed shader body: `isNaN` marks the channels whose divisor is zero so the
// shared NaN-check snippet can patch those channels of `result`.
var MOD_PACKED = "\n" +
    " vec4 result = mod(a, b);\n" +
    " vec4 isNaN = vec4(equal(b, vec4(0.0)));\n" +
    " " + CHECK_NAN_SNIPPET$2 + "\n" +
    " return result;\n";
/** WebGL kernel function for element-wise Mod. */
var mod = binaryKernelFunc({ opSnippet: MOD, packedOpSnippet: MOD_PACKED });
/** Kernel registration entry for Mod on the 'webgl' backend. */
var modConfig = {
    kernelName: tf.Mod,
    backendName: 'webgl',
    kernelFunc: mod
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2017 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * WebGL program that draws samples from per-batch outcome probabilities.
 *
 * For each output element the shader draws `random(seed)` and walks the
 * cumulative sum of `probs` for that batch row, emitting the first outcome
 * index whose cumulative probability exceeds the draw. The last outcome is the
 * fallback when none of the earlier ones is selected.
 */
var MultinomialProgram = /** @class */ (function () {
    /**
     * @param batchSize Number of rows in `probs`.
     * @param numOutcomes Number of outcomes per row.
     * @param numSamples Number of samples to draw per row.
     */
    function MultinomialProgram(batchSize, numOutcomes, numSamples) {
        this.variableNames = ['probs'];
        this.outputShape = [batchSize, numSamples];
        // Only the first numOutcomes - 1 outcomes are scanned; the last one is
        // emitted unconditionally after the loop.
        var lastOutcome = numOutcomes - 1;
        this.userCode = [
            "",
            " uniform float seed;",
            "",
            " void main() {",
            " ivec2 coords = getOutputCoords();",
            " int batch = coords[0];",
            "",
            " float r = random(seed);",
            " float cdf = 0.0;",
            "",
            " for (int i = 0; i < " + lastOutcome + "; i++) {",
            " cdf += getProbs(batch, i);",
            "",
            " if (r < cdf) {",
            " setOutput(float(i));",
            " return;",
            " }",
            " }",
            "",
            " // If no other event happened, last event happened.",
            " setOutput(float(" + lastOutcome + "));",
            " }",
            " "
        ].join("\n");
    }
    /**
     * Builds the callback that uploads `seed` to the shader's `seed` uniform.
     * The uniform location is looked up on first use and cached on the
     * program instance.
     *
     * @param seed Value written to the `seed` uniform.
     * @returns A function `(gpgpu, webGLProgram)` performing the upload.
     */
    MultinomialProgram.prototype.getCustomSetupFunc = function (seed) {
        var program = this;
        return function (gpgpu, webGLProgram) {
            if (program.seedLoc == null) {
                program.seedLoc = gpgpu.getUniformLocation(webGLProgram, 'seed');
            }
            gpgpu.gl.uniform1f(program.seedLoc, seed);
        };
    };
    return MultinomialProgram;
}());
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// Without the equality check div produces 0.9999 for a = b, which when
// floored can cause errors.
var DIV = "\n" +
    "if (a == b) {\n" +
    " return 1.0;\n" +
    "};\n" +
    "return a / b;";
// We do the same as in ./binaryop_gpu, with vec4 and ivec4.
// On Linux, the vectorized implementation produces NaNs when a and b are 0,
// so every channel gets its own scalar equality guard instead.
var DIV_PACKED = "\n" +
    " // vec4 one = vec4(equal(a, b));\n" +
    " // return one + (vec4(1.0) - one) * a / b;\n" +
    " vec4 result = a / b;\n" +
    ['x', 'y', 'z', 'w'].map(function (channel) {
        return " if(a." + channel + " == b." + channel + ") {\n" +
            " result." + channel + " = 1.;\n" +
            " }\n";
    }).join('') +
    "\n" +
    " return result;\n";
/** WebGL kernel function for element-wise RealDiv. */
var realDiv = binaryKernelFunc({
    opSnippet: DIV,
    packedOpSnippet: DIV_PACKED,
    checkOutOfBounds: true
});
/** Kernel registration entry for RealDiv on the 'webgl' backend. */
var realDivConfig = {
    kernelName: tf.RealDiv,
    backendName: 'webgl',
    kernelFunc: realDiv
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// Shader body shared by the unpacked and packed programs: element-wise a - b.
var SUB = 'return a - b;';
/**
 * WebGL kernel function for element-wise Sub. Complex inputs are enabled via
 * `supportsComplex`, and `subImplCPU` is supplied as the CPU implementation.
 */
var sub = binaryKernelFunc({ opSnippet: SUB, packedOpSnippet: SUB, supportsComplex: true, cpuKernelImpl: subImplCPU });
/** Kernel registration entry for Sub on the 'webgl' backend. */
var subConfig = { kernelName: tf.Sub, backendName: 'webgl', kernelFunc: sub };
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * WebGL kernel for Softmax along `attrs.dim`.
 *
 * Computes exp(logits - max(logits)) / sum(exp(logits - max(logits))) by
 * composing the max, sub, exp, sum and realDiv kernels. All intermediate
 * tensors are disposed before returning.
 *
 * @param args Kernel arguments: `inputs.logits`, `backend`, `attrs.dim`.
 * @returns Tensor info holding the softmax result.
 */
function softmax(args) {
    var backend = args.backend;
    var logits = args.inputs.logits;
    var axes = tf.util.parseAxisParam([args.attrs.dim], logits.shape);
    var maxLogit = max({ inputs: { x: logits }, backend: backend, attrs: { reductionIndices: axes, keepDims: false } });
    // The reductions drop the reduced axes; reshape them back in so the
    // results line up with `logits` in the binary kernels below.
    var keepDimShape = tf.backend_util.expandShapeToKeepDim(maxLogit.shape, axes);
    var maxLogitKeepDims = reshape({ inputs: { x: maxLogit }, backend: backend, attrs: { shape: keepDimShape } });
    var shifted = sub({ inputs: { a: logits, b: maxLogitKeepDims }, backend: backend });
    var exponentials = exp({ inputs: { x: shifted }, backend: backend });
    var sumExp = sum({ inputs: { x: exponentials }, backend: backend, attrs: { axis: axes, keepDims: false } });
    var sumExpKeepDims = reshape({ inputs: { x: sumExp }, backend: backend, attrs: { shape: keepDimShape } });
    var result = realDiv({ inputs: { a: exponentials, b: sumExpKeepDims }, backend: backend });
    var temporaries = [maxLogit, maxLogitKeepDims, shifted, exponentials, sumExp, sumExpKeepDims];
    for (var k = 0; k < temporaries.length; k++) {
        backend.disposeIntermediateTensorInfo(temporaries[k]);
    }
    return result;
}
/** Kernel registration entry for Softmax on the 'webgl' backend. */
var softmaxConfig = {
    kernelName: tf.Softmax,
    backendName: 'webgl',
    kernelFunc: softmax
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * WebGL kernel for Multinomial sampling.
 *
 * When `attrs.normalized` is falsy, `logits` are first converted to
 * probabilities with a softmax over the last dimension; that temporary tensor
 * is disposed after sampling. The result has dtype 'int32'.
 *
 * @param args Kernel arguments: `inputs.logits`, `backend`, and
 *     `attrs.numSamples` / `attrs.seed` / `attrs.normalized`.
 * @returns Tensor info with the sampled outcome indices.
 */
function multinomial(args) {
    var backend = args.backend;
    var attrs = args.attrs;
    var logits = args.inputs.logits;
    var normalized = attrs.normalized;
    var probs = logits;
    if (!normalized) {
        probs = softmax({ inputs: { logits: logits }, backend: backend, attrs: { dim: logits.shape.length - 1 } });
    }
    // shape[0] is used as the batch size and shape[1] as the outcome count.
    var program = new MultinomialProgram(probs.shape[0], probs.shape[1], attrs.numSamples);
    var uploadSeed = program.getCustomSetupFunc(attrs.seed);
    var samples = backend.runWebGLProgram(program, [probs], 'int32', uploadSeed);
    if (!normalized) {
        // `probs` was created here, so it must be released here.
        backend.disposeIntermediateTensorInfo(probs);
    }
    return samples;
}
/** Kernel registration entry for Multinomial on the 'webgl' backend. */
var multinomialConfig = {
    kernelName: tf.Multinomial,
    backendName: 'webgl',
    kernelFunc: multinomial
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// Shader body for element-wise negation.
var NEG = "return -x;";
// This doesn't use unaryKernelFunc because negImplCPU is not of type
// SimpleUnaryKernelImplCPU.
/**
 * WebGL kernel for Neg.
 *
 * Uses `negImplCPU` when the backend reports that `x` should be executed on
 * the CPU; otherwise runs the packed or unpacked unary shader depending on
 * the `WEBGL_PACK_UNARY_OPERATIONS` flag.
 *
 * @param args Kernel arguments: `inputs.x` and `backend`.
 * @returns Tensor info holding `-x`, with the dtype of `x`.
 */
function neg(args) {
    var backend = args.backend;
    var x = args.inputs.x;
    if (backend.shouldExecuteOnCPU([x])) {
        var cpuResult = negImplCPU(backend.texData.get(x.dataId).values, x.shape, x.dtype);
        // negImplCPU returns an [outValues, newShape] pair.
        return backend.makeTensorInfo(cpuResult[1], x.dtype, cpuResult[0]);
    }
    var Program = tf.env().getBool('WEBGL_PACK_UNARY_OPERATIONS') ?
        UnaryOpPackedProgram :
        UnaryOpProgram;
    return backend.runWebGLProgram(new Program(x.shape, NEG), [x], x.dtype);
}
/** Kernel registration entry for Neg on the 'webgl' backend. */
var negConfig = {
    kernelName: tf.Neg,
    backendName: 'webgl',
    kernelFunc: neg
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
var nonMaxSuppressionV3Impl = tf.kernel_impls.nonMaxSuppressionV3Impl;
/**
 * WebGL kernel for NonMaxSuppressionV3.
 *
 * Synchronously reads `boxes` and `scores` back from the backend and runs the
 * shared implementation from `tf.kernel_impls`; a warning is emitted on every
 * call because of that blocking read.
 *
 * @param args Kernel arguments: `inputs.boxes`, `inputs.scores`, `backend`,
 *     and `attrs.maxOutputSize` / `attrs.iouThreshold` / `attrs.scoreThreshold`.
 * @returns 1-D int32 tensor info with the selected box indices.
 */
function nonMaxSuppressionV3(args) {
    tf.backend_util.warn('tf.nonMaxSuppression() in webgl locks the UI thread. Call tf.nonMaxSuppressionAsync() instead');
    var backend = args.backend;
    var attrs = args.attrs;
    var boxesVals = backend.readSync(args.inputs.boxes.dataId);
    var scoresVals = backend.readSync(args.inputs.scores.dataId);
    var nmsResult = nonMaxSuppressionV3Impl(boxesVals, scoresVals, attrs.maxOutputSize, attrs.iouThreshold, attrs.scoreThreshold);
    var selected = nmsResult.selectedIndices;
    return backend.makeTensorInfo([selected.length], 'int32', new Int32Array(selected));
}
/** Kernel registration entry for NonMaxSuppressionV3 on the 'webgl' backend. */
var nonMaxSuppressionV3Config = {
    kernelName: tf.NonMaxSuppressionV3,
    backendName: 'webgl',
    kernelFunc: nonMaxSuppressionV3
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
var nonMaxSuppressionV4Impl = tf.kernel_impls.nonMaxSuppressionV4Impl;
/**
 * WebGL kernel for NonMaxSuppressionV4.
 *
 * Synchronously reads `boxes` and `scores` back from the backend and runs the
 * shared implementation from `tf.kernel_impls`; a warning is emitted on every
 * call because of that blocking read.
 *
 * @param args Kernel arguments: `inputs.boxes`, `inputs.scores`, `backend`,
 *     and `attrs.maxOutputSize` / `attrs.iouThreshold` /
 *     `attrs.scoreThreshold` / `attrs.padToMaxOutputSize`.
 * @returns A pair of tensor infos: the selected indices (1-D int32) and the
 *     `validOutputs` value reported by the implementation (scalar int32).
 */
function nonMaxSuppressionV4(args) {
    tf.backend_util.warn('tf.nonMaxSuppression() in webgl locks the UI thread. Call tf.nonMaxSuppressionAsync() instead');
    var backend = args.backend;
    var attrs = args.attrs;
    var boxesVals = backend.readSync(args.inputs.boxes.dataId);
    var scoresVals = backend.readSync(args.inputs.scores.dataId);
    var nmsResult = nonMaxSuppressionV4Impl(boxesVals, scoresVals, attrs.maxOutputSize, attrs.iouThreshold, attrs.scoreThreshold, attrs.padToMaxOutputSize);
    var selected = nmsResult.selectedIndices;
    var indicesInfo = backend.makeTensorInfo([selected.length], 'int32', new Int32Array(selected));
    var validOutputsInfo = backend.makeTensorInfo([], 'int32', new Int32Array([nmsResult.validOutputs]));
    return [indicesInfo, validOutputsInfo];
}
/** Kernel registration entry for NonMaxSuppressionV4 on the 'webgl' backend. */
var nonMaxSuppressionV4Config = {
    kernelName: tf.NonMaxSuppressionV4,
    backendName: 'webgl',
    kernelFunc: nonMaxSuppressionV4
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
var nonMaxSuppressionV5Impl = tf.kernel_impls.nonMaxSuppressionV5Impl;
/**
 * WebGL kernel for NonMaxSuppressionV5.
 *
 * Synchronously reads `boxes` and `scores` back from the backend and runs the
 * shared implementation from `tf.kernel_impls`; a warning is emitted on every
 * call because of that blocking read.
 *
 * @param args Kernel arguments: `inputs.boxes`, `inputs.scores`, `backend`,
 *     and `attrs.maxOutputSize` / `attrs.iouThreshold` /
 *     `attrs.scoreThreshold` / `attrs.softNmsSigma`.
 * @returns A pair of tensor infos: the selected indices (1-D int32) and the
 *     selected scores (1-D float32).
 */
function nonMaxSuppressionV5(args) {
    tf.backend_util.warn('tf.nonMaxSuppression() in webgl locks the UI thread. Call tf.nonMaxSuppressionAsync() instead');
    var backend = args.backend;
    var attrs = args.attrs;
    var boxesVals = backend.readSync(args.inputs.boxes.dataId);
    var scoresVals = backend.readSync(args.inputs.scores.dataId);
    var nmsResult = nonMaxSuppressionV5Impl(boxesVals, scoresVals, attrs.maxOutputSize, attrs.iouThreshold, attrs.scoreThreshold, attrs.softNmsSigma);
    var indices = nmsResult.selectedIndices;
    var scores = nmsResult.selectedScores;
    var indicesInfo = backend.makeTensorInfo([indices.length], 'int32', new Int32Array(indices));
    var scoresInfo = backend.makeTensorInfo([scores.length], 'float32', new Float32Array(scores));
    return [indicesInfo, scoresInfo];
}
/** Kernel registration entry for NonMaxSuppressionV5 on the 'webgl' backend. */
var nonMaxSuppressionV5Config = {
    kernelName: tf.NonMaxSuppressionV5,
    backendName: 'webgl',
    kernelFunc: nonMaxSuppressionV5
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2017 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * WebGL program producing a `[numIndices, depth]` one-hot matrix.
 *
 * Output element (row, col) is `onValue` when the rounded index stored at
 * `indices[row]` equals `col`, and `offValue` otherwise.
 */
var OneHotProgram = /** @class */ (function () {
    /**
     * @param numIndices Number of entries in the flattened `indices` input.
     * @param depth Number of columns of the output.
     * @param onValue Value written where the index matches the column.
     * @param offValue Value written everywhere else.
     */
    function OneHotProgram(numIndices, depth, onValue, offValue) {
        this.variableNames = ['indices'];
        this.outputShape = [numIndices, depth];
        // mix(off, on, t) with t in {0, 1} selects `on` exactly when the
        // index equals the output column.
        this.userCode = [
            "",
            " void main() {",
            " ivec2 coords = getOutputCoords();",
            " int index = round(getIndices(coords.x));",
            " setOutput(mix(float(" + offValue + "), float(" + onValue + "),",
            " float(index == coords.y)));",
            " }",
            " "
        ].join("\n");
    }
    return OneHotProgram;
}());
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * WebGL kernel for OneHot.
 *
 * Flattens `indices`, runs `OneHotProgram` to produce a
 * `[numIndices, depth]` matrix, then reshapes the result to
 * `indices.shape.concat([depth])`. Both temporaries are disposed.
 *
 * @param args Kernel arguments: `inputs.indices`, `backend`, and
 *     `attrs.depth` / `attrs.onValue` / `attrs.offValue`.
 * @returns Tensor info of the one-hot encoding, with the dtype of `indices`.
 */
var oneHot = function (args) {
    var backend = args.backend;
    var attrs = args.attrs;
    var indices = args.inputs.indices;
    var depth = attrs.depth;
    var numIndices = tf.util.sizeFromShape(indices.shape);
    var program = new OneHotProgram(numIndices, depth, attrs.onValue, attrs.offValue);
    var flatIndices = reshape({ inputs: { x: indices }, backend: backend, attrs: { shape: [numIndices] } });
    var flatOneHot = backend.runWebGLProgram(program, [flatIndices], indices.dtype);
    backend.disposeIntermediateTensorInfo(flatIndices);
    var fullShape = indices.shape.concat([depth]);
    var encoded = reshape({ inputs: { x: flatOneHot }, backend: backend, attrs: { shape: fullShape } });
    backend.disposeIntermediateTensorInfo(flatOneHot);
    return encoded;
};
/** Kernel registration entry for OneHot on the 'webgl' backend. */
var oneHotConfig = {
    kernelName: tf.OneHot,
    backendName: 'webgl',
    kernelFunc: oneHot
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * WebGL kernel for ZerosLike.
 *
 * Non-complex inputs are handled by `fill` with 0 (or '' for string
 * tensors). For 'complex64' the real and imaginary parts are zeroed
 * separately and recombined; the four temporaries are disposed.
 *
 * @param args Kernel arguments: `inputs.x` and `backend`.
 * @returns Tensor info with the shape and dtype of `x`, filled with zeros.
 */
function zerosLike(args) {
    var backend = args.backend;
    var x = args.inputs.x;
    if (x.dtype !== 'complex64') {
        var zero = (x.dtype === 'string') ? '' : 0;
        return fill({
            attrs: { shape: x.shape, dtype: x.dtype, value: zero },
            backend: backend
        });
    }
    var realPart = real({ inputs: { input: x }, backend: backend });
    var zeroReal = zerosLike({ inputs: { x: realPart }, backend: backend });
    var imagPart = imag({ inputs: { input: x }, backend: backend });
    var zeroImag = zerosLike({ inputs: { x: imagPart }, backend: backend });
    var result = complex({ inputs: { real: zeroReal, imag: zeroImag }, backend: backend });
    [realPart, zeroReal, imagPart, zeroImag].forEach(function (temp) {
        backend.disposeIntermediateTensorInfo(temp);
    });
    return result;
}
/** Kernel registration entry for ZerosLike on the 'webgl' backend. */
var zerosLikeConfig = {
    kernelName: tf.ZerosLike,
    backendName: 'webgl',
    kernelFunc: zerosLike
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * WebGL kernel for OnesLike.
 *
 * Non-complex numeric inputs are handled by `fill` with 1. For 'complex64'
 * the result is 1 + 0i: ones for the real part, zeros for the imaginary
 * part; the four temporaries are disposed.
 *
 * @param args Kernel arguments: `inputs.x` and `backend`.
 * @returns Tensor info with the shape and dtype of `x`, filled with ones.
 * @throws Error when `x` has dtype 'string'.
 */
function onesLike(args) {
    var backend = args.backend;
    var x = args.inputs.x;
    if (x.dtype === 'string') {
        throw new Error('onesLike is not supported under string dtype');
    }
    if (x.dtype !== 'complex64') {
        // TODO(cais, smilkov): Add WebGL shader for onesLike:
        //   https://github.com/tensorflow/tfjs/issues/1293
        return fill({ attrs: { shape: x.shape, dtype: x.dtype, value: 1 }, backend: backend });
    }
    var realPart = real({ inputs: { input: x }, backend: backend });
    var oneReal = onesLike({ inputs: { x: realPart }, backend: backend });
    var imagPart = imag({ inputs: { input: x }, backend: backend });
    var zeroImag = zerosLike({ inputs: { x: imagPart }, backend: backend });
    var result = complex({ inputs: { real: oneReal, imag: zeroImag }, backend: backend });
    [realPart, oneReal, imagPart, zeroImag].forEach(function (temp) {
        backend.disposeIntermediateTensorInfo(temp);
    });
    return result;
}
/** Kernel registration entry for OnesLike on the 'webgl' backend. */
var onesLikeConfig = {
    kernelName: tf.OnesLike,
    backendName: 'webgl',
    kernelFunc: onesLike
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * WebGL kernel for Pack (stack).
 *
 * A single input is simply expanded along `attrs.axis`. Otherwise every input
 * is checked against the shape and dtype of the first one, expanded along the
 * axis, and the expanded tensors are concatenated along that same axis. The
 * expanded temporaries are disposed afterwards.
 *
 * @param args Kernel arguments: `inputs` (array of tensor infos), `backend`,
 *     and `attrs.axis`.
 * @returns Tensor info of the stacked result.
 */
function pack(args) {
    var tensors = args.inputs;
    var backend = args.backend;
    var axis = args.attrs.axis;
    if (tensors.length === 1) {
        return expandDims({ inputs: { input: tensors[0] }, backend: backend, attrs: { dim: axis } });
    }
    var refShape = tensors[0].shape;
    var refDtype = tensors[0].dtype;
    var dtypeMismatchMessage = function () { return 'All tensors passed to stack must have matching dtypes'; };
    var k;
    for (k = 0; k < tensors.length; k++) {
        tf.util.assertShapesMatch(refShape, tensors[k].shape, 'All tensors passed to stack must have matching shapes');
        tf.util.assert(refDtype === tensors[k].dtype, dtypeMismatchMessage);
    }
    var expanded = [];
    for (k = 0; k < tensors.length; k++) {
        expanded.push(expandDims({ inputs: { input: tensors[k] }, backend: backend, attrs: { dim: axis } }));
    }
    var stacked = concat({ inputs: expanded, backend: backend, attrs: { axis: axis } });
    for (k = 0; k < expanded.length; k++) {
        backend.disposeIntermediateTensorInfo(expanded[k]);
    }
    return stacked;
}
/** Kernel registration entry for Pack on the 'webgl' backend. */
var packConfig = {
    kernelName: tf.Pack,
    backendName: 'webgl',
    kernelFunc: pack
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2017 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Unpacked WebGL program that pads `x` with a constant.
 *
 * Output coordinates outside `[start, end)` receive the `value` uniform; all
 * others read `x` at the coordinate shifted back by `start`.
 */
var PadProgram = /** @class */ (function () {
    /**
     * @param xShape Shape of the input tensor `x`.
     * @param paddings One `[before, after]` pair per dimension of `x`.
     * @param constantValue Not read by the constructor; the fill value is
     *     supplied at run time through the `value` uniform (see
     *     `getCustomSetupFunc`).
     */
    function PadProgram(xShape, paddings, constantValue) {
        this.variableNames = ['x'];
        this.outputShape = paddings.map(function (pad, axis) {
            return pad[0] /* beforePad */ + xShape[axis] + pad[1] /* afterPad */;
        });
        var rank = xShape.length;
        var type = getCoordsDataType(rank);
        var startList = paddings.map(function (pad) { return pad[0]; }).join(',');
        var endList = paddings.map(function (pad, axis) { return pad[0] + xShape[axis]; }).join(',');
        if (rank === 1) {
            // Scalar coordinates: plain comparisons instead of vector any().
            this.userCode = "\n int start = " + startList + ";\n" +
                " int end = " + endList + ";\n" +
                " uniform float value;\n" +
                "\n" +
                " void main() {\n" +
                " int outC = getOutputCoords();\n" +
                " if (outC < start || outC >= end) {\n" +
                " setOutput(value);\n" +
                " } else {\n" +
                " setOutput(getX(outC - start));\n" +
                " }\n" +
                " }\n" +
                " ";
        }
        else {
            var sourceCoords = ['coords[0]', 'coords[1]', 'coords[2]', 'coords[3]'].slice(0, rank).join(',');
            this.userCode = "\n " + type + " start = " + type + "(" + startList + ");\n" +
                " " + type + " end = " + type + "(" + endList + ");\n" +
                " uniform float value;\n" +
                "\n" +
                " void main() {\n" +
                " " + type + " outC = getOutputCoords();\n" +
                " if (any(lessThan(outC, start)) || any(greaterThanEqual(outC, end))) {\n" +
                " setOutput(value);\n" +
                " } else {\n" +
                " " + type + " coords = outC - start;\n" +
                " setOutput(getX(" + sourceCoords + "));\n" +
                " }\n" +
                " }\n" +
                " ";
        }
    }
    /**
     * Builds the callback that uploads the pad constant to the `value`
     * uniform. The uniform location is looked up on first use (without
     * throwing) and cached on the program instance.
     *
     * @param value Constant written to the `value` uniform.
     * @returns A function `(gpgpu, webGLProgram)` performing the upload.
     */
    PadProgram.prototype.getCustomSetupFunc = function (value) {
        var program = this;
        return function (gpgpu, webGLProgram) {
            if (program.valueLoc == null) {
                program.valueLoc = gpgpu.getUniformLocationNoThrow(webGLProgram, 'value');
            }
            gpgpu.gl.uniform1f(program.valueLoc, value);
        };
    };
    return PadProgram;
}());
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Packed WebGL program that pads `x` with a constant.
 *
 * Each invocation fills up to four channels of the packed output texel (two
 * for rank-1 inputs). A channel whose coordinate lies outside `[start, end)`
 * receives the `value` uniform; otherwise it reads `x` at the coordinate
 * shifted back by `start`.
 */
var PadPackedProgram = /** @class */ (function () {
    /**
     * @param xShape Shape of the input tensor `x`.
     * @param paddings One `[before, after]` pair per dimension of `x`.
     * @param constantValue Not read by the constructor; the fill value is
     *     supplied at run time through the `value` uniform (see
     *     `getCustomSetupFunc`).
     */
    function PadPackedProgram(xShape, paddings, constantValue) {
        this.variableNames = ['x'];
        this.packedInputs = true;
        this.packedOutput = true;
        this.outputShape = paddings.map(function (pad, axis) {
            return pad[0] /* beforePad */ + xShape[axis] + pad[1] /* afterPad */;
        });
        var outShape = this.outputShape;
        var rank = xShape.length;
        var dtype = getCoordsDataType(rank);
        var startList = paddings.map(function (pad) { return pad[0]; }).join(',');
        var endList = paddings.map(function (pad, axis) { return pad[0] + xShape[axis]; }).join(',');
        var rcChannels = getChannels('rc', rank);
        var sourceChannels = getChannels('source', rank);
        var lastCoord = rcChannels[rank - 1];
        var cLimit = lastCoord + " < " + outShape[rank - 1];
        var sourceArgs = sourceChannels.join();
        var innerDims;
        var paddingArea;
        // One coordinate-setup snippet per output channel. The blocks opened
        // here are closed by the next snippet or by the trailer appended
        // after the per-channel code.
        var componentSetup = [
            dtype + " rc = outputLoc;",
            lastCoord + " += 1;\n if(" + cLimit + ") {\n "
        ];
        if (rank === 1) {
            innerDims = 'source';
            paddingArea = 'rc < start || rc >= end';
        }
        else {
            var prevCoord = rcChannels[rank - 2];
            innerDims = "vec2(" + sourceChannels.slice(-2).join() + ")";
            paddingArea = 'any(lessThan(rc, start)) || any(greaterThanEqual(rc, end))';
            componentSetup.push("}\n rc = outputLoc;\n " + prevCoord + " += 1;\n if(" + prevCoord + " < " + outShape[rank - 2] + ") {", " " + lastCoord + " += 1;\n if(" + cLimit + ") {");
        }
        var mainLoop = componentSetup.map(function (setup, channel) {
            return "\n " + setup + "\n" +
                " if (" + paddingArea + ") {\n" +
                " result[" + channel + "] = float(value);\n" +
                " } else {\n" +
                " " + dtype + " source = rc - start;\n" +
                " result[" + channel + "] = getChannel(getX(" + sourceArgs + "), " + innerDims + ");\n" +
                " }\n" +
                " ";
        }).join('');
        // Close the bounds-check blocks left open by the setup snippets.
        mainLoop += (rank === 1 ? "} " : "}}");
        this.userCode = "\n const " + dtype + " start = " + dtype + "(" + startList + ");\n" +
            " const " + dtype + " end = " + dtype + "(" + endList + ");\n" +
            " uniform float value;\n" +
            "\n" +
            " void main() {\n" +
            " " + dtype + " outputLoc = getOutputCoords();\n" +
            " vec4 result = vec4(0.);\n" +
            " " + mainLoop + "\n" +
            " setOutput(result);\n" +
            " }\n" +
            " ";
    }
    /**
     * Builds the callback that uploads the pad constant to the `value`
     * uniform. The uniform location is looked up on first use (without
     * throwing) and cached on the program instance.
     *
     * @param value Constant written to the `value` uniform.
     * @returns A function `(gpgpu, webGLProgram)` performing the upload.
     */
    PadPackedProgram.prototype.getCustomSetupFunc = function (value) {
        var program = this;
        return function (gpgpu, webGLProgram) {
            if (program.valueLoc == null) {
                program.valueLoc = gpgpu.getUniformLocationNoThrow(webGLProgram, 'value');
            }
            gpgpu.gl.uniform1f(program.valueLoc, value);
        };
    };
    return PadPackedProgram;
}());
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * WebGL kernel for PadV2: pads `inputs.x` with `attrs.constantValue`
 * according to `attrs.paddings`.
 *
 * Uses the packed shader program when the `WEBGL_PACK_ARRAY_OPERATIONS`
 * environment flag is set and the unpacked one otherwise. The pad value is
 * handed to the shader through the program's custom setup callback.
 *
 * @param args `{inputs: {x}, backend, attrs: {paddings, constantValue}}`.
 * @returns Whatever `backend.runWebGLProgram` returns, using `x.dtype` as the
 *     output dtype.
 */
var padV2 = function (args) {
    var inputs = args.inputs, backend = args.backend, attrs = args.attrs;
    var x = inputs.x;
    var paddings = attrs.paddings, constantValue = attrs.constantValue;
    var program = tf.env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') ?
        new PadPackedProgram(x.shape, paddings, constantValue) :
        new PadProgram(x.shape, paddings, constantValue);
    var customSetup = program.getCustomSetupFunc(constantValue);
    return backend.runWebGLProgram(program, [x], x.dtype, customSetup);
};
|
|
/** Kernel config pairing `tf.PadV2` with `padV2` for the 'webgl' backend. */
var padV2Config = {
    kernelName: tf.PadV2,
    backendName: 'webgl',
    kernelFunc: padV2
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * GLSL body for element-wise `a ^ b` (unpacked).
 * Returns NAN for a negative base with a non-integer exponent, 1.0 for a
 * zero exponent, and otherwise restores the sign of the base when the
 * rounded `mod(b, 2.0)` equals 1.
 */
var POW = "\n if(a < 0.0 && floor(b) < b){\n return NAN;\n }\n if (b == 0.0) {\n return 1.0;\n }\n return (round(mod(b, 2.0)) != 1) ?\n pow(abs(a), b) : sign(a) * pow(abs(a), b);\n";
/**
 * GLSL body for element-wise `a ^ b` on packed vec4 operands. Ends with the
 * shared NaN-check snippet followed by `return result;`.
 */
var POW_PACKED = "\n // isModRound1 has 1 for components with round(mod(b, 2.0)) == 1, 0 otherwise.\n vec4 isModRound1 = vec4(equal(round(mod(b, 2.0)), ivec4(1)));\n vec4 multiplier = sign(a) * isModRound1 + (vec4(1.0) - isModRound1);\n vec4 result = multiplier * pow(abs(a), b);\n\n // Ensure that a^0 = 1, including 0^0 = 1 as this correspond to TF and JS\n bvec4 isExpZero = equal(b, vec4(0.0));\n result.r = isExpZero.r ? 1.0 : result.r;\n result.g = isExpZero.g ? 1.0 : result.g;\n result.b = isExpZero.b ? 1.0 : result.b;\n result.a = isExpZero.a ? 1.0 : result.a;\n\n vec4 isNaN = vec4(lessThan(a, vec4(0.0))) * vec4(lessThan(floor(b), b));\n " +
    CHECK_NAN_SNIPPET$2 + "\n return result;\n";
/** Pow kernel built from the unpacked and packed snippets above. */
var pow = binaryKernelFunc({ opSnippet: POW, packedOpSnippet: POW_PACKED });
/** Kernel config pairing `tf.Pow` with `pow` for the 'webgl' backend. */
var powConfig = {
    kernelName: tf.Pow,
    backendName: 'webgl',
    kernelFunc: pow
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * WebGL kernel for Prod: multiplies the elements of `inputs.x` along
 * `attrs.axis`.
 *
 * Steps:
 *  1. If the reduction axes need a permutation, transpose `x` and reduce over
 *     the innermost axes instead.
 *  2. Either compute on the CPU (`prodImplCPU`) when the backend asks for it,
 *     or reshape to `[-1, inSize]`, run the 'prod' reduction and reshape to
 *     the output shape.
 *  3. If `keepDims` is set, reshape once more to re-insert the reduced
 *     dimensions.
 *
 * Every intermediate tensor info created along the way is collected in
 * `toDispose` and released before returning; only the final result survives.
 *
 * @param args `{inputs: {x}, backend, attrs: {axis, keepDims}}`.
 * @returns The tensor info holding the product.
 */
function prod(args) {
    var inputs = args.inputs, backend = args.backend, attrs = args.attrs;
    var x = inputs.x;
    var axis = attrs.axis, keepDims = attrs.keepDims;
    var xRank = x.shape.length;
    var toDispose = [];
    var origAxes = tf.util.parseAxisParam(axis, x.shape);
    var axes = origAxes;
    var permutedAxes = tf.backend_util.getAxesPermutation(axes, xRank);
    var permutedX = x;
    if (permutedAxes != null) {
        permutedX = transpose({ inputs: { x: x }, backend: backend, attrs: { perm: permutedAxes } });
        axes = tf.backend_util.getInnerMostAxes(axes.length, xRank);
        toDispose.push(permutedX);
    }
    tf.backend_util.assertAxesAreInnerMostDims('prod', axes, xRank);
    var res;
    if (backend.shouldExecuteOnCPU([permutedX])) {
        var xVals = backend.texData.get(permutedX.dataId).values;
        var _a = prodImplCPU(permutedX.shape, permutedX.dtype, xVals, axes), outVals = _a.outVals, outShape = _a.outShape, outDtype = _a.outDtype;
        res = backend.makeTensorInfo(outShape, outDtype, outVals);
    }
    else {
        var _b = tf.backend_util.computeOutAndReduceShapes(permutedX.shape, axes), outShape = _b[0], reduceShape = _b[1];
        var inSize = tf.util.sizeFromShape(reduceShape);
        var a2D = reshape({ inputs: { x: permutedX }, backend: backend, attrs: { shape: [-1, inSize] } });
        var outputDType = tf.sumOutType(x.dtype);
        var reduced = reduce(a2D, outputDType, 'prod', backend);
        res = reshape({ inputs: { x: reduced }, backend: backend, attrs: { shape: outShape } });
        toDispose.push(a2D);
        toDispose.push(reduced);
    }
    if (keepDims) {
        // The un-expanded result becomes an intermediate once it is reshaped.
        toDispose.push(res);
        var newShape = tf.backend_util.expandShapeToKeepDim(res.shape, origAxes);
        res = reshape({ inputs: { x: res }, backend: backend, attrs: { shape: newShape } });
    }
    toDispose.forEach(function (t) { return backend.disposeIntermediateTensorInfo(t); });
    return res;
}
|
|
/** Kernel config pairing `tf.Prod` with `prod` for the 'webgl' backend. */
var prodConfig = {
    kernelName: tf.Prod,
    backendName: 'webgl',
    kernelFunc: prod
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * WebGL kernel for Range. The values are produced by `rangeImplCPU` and then
 * wrapped in a 1-D tensor info whose length equals the number of values.
 *
 * @param args `{backend, attrs: {start, stop, step, dtype}}`.
 * @returns The tensor info created by `backend.makeTensorInfo`.
 */
var range = function (args) {
    var backend = args.backend, attrs = args.attrs;
    var start = attrs.start, stop = attrs.stop, step = attrs.step, dtype = attrs.dtype;
    var values = rangeImplCPU(start, stop, step, dtype);
    return backend.makeTensorInfo([values.length], dtype, values);
};
|
|
/** Kernel config pairing `tf.Range` with `range` for the 'webgl' backend. */
var rangeConfig = {
    kernelName: tf.Range,
    backendName: 'webgl',
    kernelFunc: range
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/** GLSL body for element-wise reciprocal. */
var RECIPROCAL = "return 1.0 / x;";
/** Reciprocal kernel built from the snippet above. */
var reciprocal = unaryKernelFunc({ opSnippet: RECIPROCAL });
/** Kernel config pairing `tf.Reciprocal` with `reciprocal` for the 'webgl' backend. */
var reciprocalConfig = {
    kernelName: tf.Reciprocal,
    backendName: 'webgl',
    kernelFunc: reciprocal
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * GLSL body for element-wise ReLU (unpacked): the shared NaN-check snippet
 * followed by clamping negative values to 0.0.
 */
var RELU$2 = CHECK_NAN_SNIPPET + "\n return (x < 0.0) ? 0.0 : x;\n";
/**
 * GLSL body for ReLU on packed vec4 operands. Components that are NaN in the
 * input are copied through unchanged.
 */
var RELU_PACKED = "\n vec4 result = x * vec4(greaterThanEqual(x, vec4(0.0)));\n bvec4 isNaN = isnan(x);\n\n result.r = isNaN.r ? x.r : result.r;\n result.g = isNaN.g ? x.g : result.g;\n result.b = isNaN.b ? x.b : result.b;\n result.a = isNaN.a ? x.a : result.a;\n\n return result;\n";
/** ReLU kernel built from the unpacked and packed snippets above. */
var relu = unaryKernelFunc({ opSnippet: RELU$2, packedOpSnippet: RELU_PACKED });
/** Kernel config pairing `tf.Relu` with `relu` for the 'webgl' backend. */
var reluConfig = {
    kernelName: tf.Relu,
    backendName: 'webgl',
    kernelFunc: relu
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * GLSL body for element-wise ReLU6 (unpacked): the shared NaN-check snippet
 * followed by clamping to the range [0.0, 6.0].
 */
var RELU6$2 = CHECK_NAN_SNIPPET + "\n return (x < 0.0) ? 0.0 : min(6.0, x);\n";
/**
 * GLSL body for ReLU6 on packed vec4 operands. Components that are NaN in the
 * input are copied through unchanged.
 */
var RELU6_PACKED = "\n vec4 result = min(x, vec4(6.)) * vec4(greaterThanEqual(x, vec4(0.0)));\n bvec4 isNaN = isnan(x);\n\n result.r = isNaN.r ? x.r : result.r;\n result.g = isNaN.g ? x.g : result.g;\n result.b = isNaN.b ? x.b : result.b;\n result.a = isNaN.a ? x.a : result.a;\n\n return result;\n";
/** ReLU6 kernel built from the unpacked and packed snippets above. */
var relu6 = unaryKernelFunc({ opSnippet: RELU6$2, packedOpSnippet: RELU6_PACKED });
/** Kernel config pairing `tf.Relu6` with `relu6` for the 'webgl' backend. */
var relu6Config = {
    kernelName: tf.Relu6,
    backendName: 'webgl',
    kernelFunc: relu6
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2017 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Unpacked shader program for bilinear image resizing.
 *
 * For every output element the shader computes a fractional source row/column,
 * reads the four surrounding input texels via `getA` and interpolates between
 * them.
 */
var ResizeBilinearProgram = /** @class */ (function () {
    /**
     * @param inputShape `[batch, oldHeight, oldWidth, depth]`.
     * @param newHeight Output height.
     * @param newWidth Output width.
     * @param alignCorners When truthy (and the new size along an axis is > 1),
     *     the input/output ratio for that axis is computed from `size - 1`.
     * @param halfPixelCenters When truthy, output coordinates are shifted by
     *     0.5 before scaling and shifted back by 0.5 afterwards.
     */
    function ResizeBilinearProgram(inputShape, newHeight, newWidth, alignCorners, halfPixelCenters) {
        this.variableNames = ['A'];
        var batch = inputShape[0], oldHeight = inputShape[1], oldWidth = inputShape[2], depth = inputShape[3];
        this.outputShape = [batch, newHeight, newWidth, depth];
        var effectiveInSize = [
            (alignCorners && newHeight > 1) ? oldHeight - 1 : oldHeight,
            (alignCorners && newWidth > 1) ? oldWidth - 1 : oldWidth
        ];
        var effectiveOutSize = [
            (alignCorners && newHeight > 1) ? newHeight - 1 : newHeight,
            (alignCorners && newWidth > 1) ? newWidth - 1 : newWidth
        ];
        var heightRatio = effectiveInSize[0] / effectiveOutSize[0];
        var widthRatio = effectiveInSize[1] / effectiveOutSize[1];
        var sourceFracIndexRC;
        if (halfPixelCenters) {
            sourceFracIndexRC =
                "(vec2(yRC) + vec2(0.5)) * effectiveInputOverOutputRatioRC" +
                    " - vec2(0.5)";
        }
        else {
            sourceFracIndexRC = "vec2(yRC) * effectiveInputOverOutputRatioRC";
        }
        this.userCode = "\n const vec2 effectiveInputOverOutputRatioRC = vec2(\n " + heightRatio + ",\n " + widthRatio + ");\n const vec2 inputShapeRC = vec2(" + oldHeight + ".0, " + oldWidth + ".0);\n\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int d = coords[3];\n ivec2 yRC = coords.yz;\n\n // Fractional source index.\n vec2 sourceFracIndexRC = " + sourceFracIndexRC + ";\n\n // Compute the four integer indices.\n ivec2 sourceFloorRC = ivec2(max(sourceFracIndexRC, vec2(0.0)));\n ivec2 sourceCeilRC = ivec2(\n min(inputShapeRC - 1.0, ceil(sourceFracIndexRC)));\n\n float topLeft = getA(b, sourceFloorRC.x, sourceFloorRC.y, d);\n float bottomLeft = getA(b, sourceCeilRC.x, sourceFloorRC.y, d);\n float topRight = getA(b, sourceFloorRC.x, sourceCeilRC.y, d);\n float bottomRight = getA(b, sourceCeilRC.x, sourceCeilRC.y, d);\n\n vec2 fracRC = sourceFracIndexRC - vec2(sourceFloorRC);\n\n float top = topLeft + (topRight - topLeft) * fracRC.y;\n float bottom = bottomLeft + (bottomRight - bottomLeft) * fracRC.y;\n float newValue = top + (bottom - top) * fracRC.x;\n\n setOutput(newValue);\n }\n ";
    }
    return ResizeBilinearProgram;
}());
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Packed shader program for bilinear image resizing.
 *
 * Each invocation fills one packed vec4. `yRC.z` holds the coordinate
 * `coords.z + 1`, so the shader interpolates for the current and the next
 * position along that axis in parallel; the `hasNextCol` / `hasNextRow`
 * guards zero out components that would fall outside the output.
 */
var ResizeBilinearPackedProgram = /** @class */ (function () {
    /**
     * @param inputShape `[batch, oldHeight, oldWidth, depth]`.
     * @param newHeight Output height.
     * @param newWidth Output width.
     * @param alignCorners When truthy (and the new size along an axis is > 1),
     *     the input/output ratio for that axis is computed from `size - 1`.
     * @param halfPixelCenters When truthy, output coordinates are shifted by
     *     0.5 before scaling and shifted back by 0.5 afterwards.
     */
    function ResizeBilinearPackedProgram(inputShape, newHeight, newWidth, alignCorners, halfPixelCenters) {
        this.variableNames = ['A'];
        this.packedInputs = true;
        this.packedOutput = true;
        var batch = inputShape[0], oldHeight = inputShape[1], oldWidth = inputShape[2], depth = inputShape[3];
        this.outputShape = [batch, newHeight, newWidth, depth];
        var effectiveInSize = [
            (alignCorners && newHeight > 1) ? oldHeight - 1 : oldHeight,
            (alignCorners && newWidth > 1) ? oldWidth - 1 : oldWidth
        ];
        var effectiveOutSize = [
            (alignCorners && newHeight > 1) ? newHeight - 1 : newHeight,
            (alignCorners && newWidth > 1) ? newWidth - 1 : newWidth
        ];
        var heightRatio = effectiveInSize[0] / effectiveOutSize[0];
        var widthRatio = effectiveInSize[1] / effectiveOutSize[1];
        var sourceFracIndexRC;
        if (halfPixelCenters) {
            sourceFracIndexRC = "(vec3(yRC) + vec3(0.5)) * " +
                "effectiveInputOverOutputRatioRC - vec3(0.5)";
        }
        else {
            sourceFracIndexRC = "vec3(yRC) * effectiveInputOverOutputRatioRC";
        }
        // The width ratio and width appear twice because the third vector
        // component tracks the neighbouring position along the width axis.
        this.userCode = "\n const vec3 effectiveInputOverOutputRatioRC = vec3(\n " + heightRatio + ",\n " + widthRatio + ",\n " + widthRatio + ");\n const vec3 inputShapeRC = vec3(" + oldHeight + ".0, " + oldWidth + ".0,\n " + oldWidth + ".0);\n\n float getAValue(int b, int r, int c, int d) {\n return getChannel(getA(b, r, c, d), vec2(c, d));\n }\n\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int d = coords[3];\n // Calculate values for next column in yRC.z.\n ivec3 yRC = coords.yzz + ivec3(0, 0, 1);\n\n // Fractional source index.\n vec3 sourceFracIndexRC = " + sourceFracIndexRC + ";\n\n // Compute the four integer indices.\n ivec3 sourceFloorRC = ivec3(max(sourceFracIndexRC, vec3(0.0)));\n ivec3 sourceCeilRC = ivec3(\n min(inputShapeRC - 1.0, ceil(sourceFracIndexRC)));\n\n // Should we calculate next column and row elements in 2x2 packed cell.\n bool hasNextCol = d < " + (depth - 1) + ";\n bool hasNextRow = coords.z < " + (newWidth - 1) + ";\n\n // In parallel, construct four corners for all four components in\n // packed 2x2 cell.\n" +
            " vec4 topLeft = vec4(\n getAValue(b, sourceFloorRC.x, sourceFloorRC.y, d),\n hasNextCol ? getAValue(b, sourceFloorRC.x, sourceFloorRC.y, d + 1)\n : 0.0,\n hasNextRow ? getAValue(b, sourceFloorRC.x, sourceFloorRC.z, d)\n : 0.0,\n (hasNextRow && hasNextCol) ?\n getAValue(b, sourceFloorRC.x, sourceFloorRC.z, d + 1) : 0.0);\n\n" +
            " vec4 bottomLeft = vec4(\n getAValue(b, sourceCeilRC.x, sourceFloorRC.y, d),\n hasNextCol ? getAValue(b, sourceCeilRC.x, sourceFloorRC.y, d + 1)\n : 0.0,\n hasNextRow ? getAValue(b, sourceCeilRC.x, sourceFloorRC.z, d)\n : 0.0,\n (hasNextRow && hasNextCol) ?\n getAValue(b, sourceCeilRC.x, sourceFloorRC.z, d + 1) : 0.0);\n\n" +
            " vec4 topRight = vec4(\n getAValue(b, sourceFloorRC.x, sourceCeilRC.y, d),\n hasNextCol ? getAValue(b, sourceFloorRC.x, sourceCeilRC.y, d + 1)\n : 0.0,\n hasNextRow ? getAValue(b, sourceFloorRC.x, sourceCeilRC.z, d)\n : 0.0,\n (hasNextRow && hasNextCol) ?\n getAValue(b, sourceFloorRC.x, sourceCeilRC.z, d + 1) : 0.0);\n\n" +
            " vec4 bottomRight = vec4(\n getAValue(b, sourceCeilRC.x, sourceCeilRC.y, d),\n hasNextCol ? getAValue(b, sourceCeilRC.x, sourceCeilRC.y, d + 1)\n : 0.0,\n hasNextRow ? getAValue(b, sourceCeilRC.x, sourceCeilRC.z, d)\n : 0.0,\n (hasNextRow && hasNextCol) ?\n getAValue(b, sourceCeilRC.x, sourceCeilRC.z, d + 1) : 0.0);\n\n" +
            " vec3 fracRC = sourceFracIndexRC - vec3(sourceFloorRC);\n\n vec4 top = mix(topLeft, topRight, fracRC.yyzz);\n vec4 bottom = mix(bottomLeft, bottomRight, fracRC.yyzz);\n vec4 newValue = mix(top, bottom, fracRC.x);\n\n setOutput(newValue);\n }\n ";
    }
    return ResizeBilinearPackedProgram;
}());
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * WebGL kernel for ResizeBilinear.
 *
 * Uses the packed shader program when the `WEBGL_PACK_IMAGE_OPERATIONS`
 * environment flag is set and the unpacked one otherwise. The output dtype is
 * always 'float32'.
 *
 * @param args `{inputs: {images}, backend,
 *     attrs: {alignCorners, halfPixelCenters, size: [newHeight, newWidth]}}`.
 * @returns Whatever `backend.runWebGLProgram` returns.
 */
function resizeBilinear(args) {
    var inputs = args.inputs, backend = args.backend, attrs = args.attrs;
    var images = inputs.images;
    var alignCorners = attrs.alignCorners, halfPixelCenters = attrs.halfPixelCenters, size = attrs.size;
    var newHeight = size[0], newWidth = size[1];
    var program = tf.env().getBool('WEBGL_PACK_IMAGE_OPERATIONS') ?
        new ResizeBilinearPackedProgram(images.shape, newHeight, newWidth, alignCorners, halfPixelCenters) :
        new ResizeBilinearProgram(images.shape, newHeight, newWidth, alignCorners, halfPixelCenters);
    return backend.runWebGLProgram(program, [images], 'float32');
}
|
|
/** Kernel config pairing `tf.ResizeBilinear` with `resizeBilinear` for the 'webgl' backend. */
var resizeBilinearConfig = {
    kernelName: tf.ResizeBilinear,
    backendName: 'webgl',
    kernelFunc: resizeBilinear
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Shader program computing the gradient of bilinear resizing with respect to
 * the input image.
 *
 * For each input position `(r, c)` the shader scans a window of `dy` and
 * accumulates every `dy` value whose forward interpolation touched that
 * position, weighted by the matching interpolation coefficient.
 */
var ResizeBilinearBackpropProgram = /** @class */ (function () {
    /**
     * @param dyShape Shape of the upstream gradient; indices 1 and 2 are read
     *     as its height and width.
     * @param inputShape Shape of the original input; also the output shape of
     *     this program. Indices 1 and 2 are read as height and width.
     * @param alignCorners When truthy (and the `dy` size along an axis is
     *     > 1), the scale for that axis is computed from `size - 1`.
     */
    function ResizeBilinearBackpropProgram(dyShape, inputShape, alignCorners) {
        this.variableNames = ['dy'];
        this.outputShape = inputShape;
        var xHeight = inputShape[1], xWidth = inputShape[2];
        var yHeight = dyShape[1], yWidth = dyShape[2];
        // In the backward pass, we want to find the pixels that were generated
        // for each pixel of the input image in the forward pass and add the
        // corresponding coefficient from dy to the gradient (with some
        // interpolation).
        var effectiveXSize = [
            (alignCorners && yHeight > 1) ? xHeight - 1 : xHeight,
            (alignCorners && yWidth > 1) ? xWidth - 1 : xWidth
        ];
        var effectiveYSize = [
            (alignCorners && yHeight > 1) ? yHeight - 1 : yHeight,
            (alignCorners && yWidth > 1) ? yWidth - 1 : yWidth
        ];
        var heightScale = effectiveXSize[0] / effectiveYSize[0];
        var widthScale = effectiveXSize[1] / effectiveYSize[1];
        var invHeightScale = 1 / heightScale;
        var invWidthScale = 1 / widthScale;
        // This defines the size of the window of values around a particular
        // index in dy that we want to search for contributions to dx.
        var winHeight = (Math.ceil(invHeightScale) * 2) + 2;
        var winWidth = (Math.ceil(invWidthScale) * 2) + 2;
        this.userCode = "\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int d = coords[3];\n int r = coords[1];\n int c = coords[2];\n\n float accumulator = 0.0;\n\n const float heightScale = float(" + heightScale + ");\n const float widthScale = float(" + widthScale + ");\n\n const float invHeightScale = float(" + invHeightScale + ");\n const float invWidthScale = float(" + invWidthScale + ");\n\n const int winHeight = int(" + winHeight + ");\n const int winWidth = int(" + winWidth + ");\n\n // Compute bounds for where in dy we will look\n float startRLerp = floor(float(r) * invHeightScale);\n int startDyR = int(startRLerp - float(winHeight / 2));\n\n float startCLerp = floor(float(c) * invWidthScale);\n int startDyC = int(startCLerp - float(winWidth / 2));\n\n // Loop over dy\n for (int dyROffset = 0; dyROffset < winHeight; dyROffset++) {\n int dyR = dyROffset + startDyR;\n\n // Guard against the window exceeding the bounds of dy\n if (dyR < 0 || dyR >= " + yHeight + ") {\n continue;\n }\n\n for (int dyCOffset = 0; dyCOffset < winWidth; dyCOffset++) {\n int dyC = dyCOffset + startDyC;\n\n // Guard against the window exceeding the bounds of dy\n if (dyC < 0 || dyC >= " + yWidth + ") {\n continue;\n }\n\n float dxR = float(dyR) * heightScale;\n int topDxRIndex = int(floor(dxR));\n int bottomDxRIndex = int(min(ceil(dxR), " + (xHeight - 1) + ".0));\n float dxRLerp = dxR - float(topDxRIndex);\n float inverseDxRLerp = 1.0 - dxRLerp;\n\n float dxC = float(dyC) * widthScale;\n int leftDxCIndex = int(floor(dxC));\n int rightDxCIndex = int(min(ceil(dxC), " + (xWidth - 1) + ".0));\n float dxCLerp = dxC - float(leftDxCIndex);\n float inverseDxCLerp = 1.0 - dxCLerp;\n\n" +
            " if (r == topDxRIndex && c == leftDxCIndex) {\n // topLeft\n accumulator +=\n getDy(b, dyR, dyC, d) * inverseDxRLerp * inverseDxCLerp;\n }\n\n" +
            " if (r == topDxRIndex && c == rightDxCIndex) {\n // topRight\n accumulator += getDy(b, dyR, dyC, d) * inverseDxRLerp * dxCLerp;\n }\n\n" +
            " if (r == bottomDxRIndex && c == leftDxCIndex) {\n // bottomLeft\n accumulator += getDy(b, dyR, dyC, d) * dxRLerp * inverseDxCLerp;\n }\n\n" +
            " if (r == bottomDxRIndex && c == rightDxCIndex) {\n // bottomRight\n accumulator += getDy(b, dyR, dyC, d) * dxRLerp * dxCLerp;\n }\n }\n }\n // End loop over dy\n\n setOutput(accumulator);\n }\n ";
    }
    return ResizeBilinearBackpropProgram;
}());
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * WebGL kernel for ResizeBilinearGrad: runs the backprop shader over `dy`
 * to produce a gradient with the shape of `images`.
 *
 * @param args `{inputs: {images, dy}, backend, attrs: {alignCorners}}`.
 * @returns Whatever `backend.runWebGLProgram` returns, using `dy.dtype` as
 *     the output dtype.
 */
function resizeBilinearGrad(args) {
    var inputs = args.inputs, backend = args.backend, attrs = args.attrs;
    var images = inputs.images, dy = inputs.dy;
    var alignCorners = attrs.alignCorners;
    var program = new ResizeBilinearBackpropProgram(dy.shape, images.shape, alignCorners);
    return backend.runWebGLProgram(program, [dy], dy.dtype);
}
|
|
/** Kernel config pairing `tf.ResizeBilinearGrad` with `resizeBilinearGrad` for the 'webgl' backend. */
var resizeBilinearGradConfig = {
    kernelName: tf.ResizeBilinearGrad,
    backendName: 'webgl',
    kernelFunc: resizeBilinearGrad
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Unpacked shader program for nearest-neighbor image resizing.
 *
 * For every output element the shader computes a fractional source row/column,
 * converts it to an integer index (clamped to the last input row/column) and
 * copies that input value.
 */
var ResizeNearestNeighborProgram = /** @class */ (function () {
    /**
     * @param inputShape `[batch, oldHeight, oldWidth, depth]`.
     * @param newHeight Output height.
     * @param newWidth Output width.
     * @param alignCorners When truthy (and the new size along an axis is > 1),
     *     the input/output ratio for that axis is computed from `size - 1`;
     *     it also switches the index conversion from floor to round-half-up.
     * @param halfPixelCenters When truthy, output coordinates are shifted by
     *     0.5 before scaling and the result is clamped to be non-negative.
     */
    function ResizeNearestNeighborProgram(inputShape, newHeight, newWidth, alignCorners, halfPixelCenters) {
        this.variableNames = ['A'];
        var batch = inputShape[0], oldHeight = inputShape[1], oldWidth = inputShape[2], depth = inputShape[3];
        this.outputShape = [batch, newHeight, newWidth, depth];
        var effectiveInSize = [
            (alignCorners && newHeight > 1) ? oldHeight - 1 : oldHeight,
            (alignCorners && newWidth > 1) ? oldWidth - 1 : oldWidth
        ];
        var effectiveOutSize = [
            (alignCorners && newHeight > 1) ? newHeight - 1 : newHeight,
            (alignCorners && newWidth > 1) ? newWidth - 1 : newWidth
        ];
        var heightRatio = effectiveInSize[0] / effectiveOutSize[0];
        var widthRatio = effectiveInSize[1] / effectiveOutSize[1];
        // When alignCorners is false, the value is rounded down with floor;
        // otherwise 0.5 is added first so that floor rounds to nearest.
        var roundBase = alignCorners ? '0.5' : '0.0';
        var sourceFracIndexRC;
        if (halfPixelCenters) {
            sourceFracIndexRC =
                "max((vec2(yRC) + vec2(0.5)) * effectiveInputOverOutputRatioRC" +
                    ", vec2(0.0))";
        }
        else {
            sourceFracIndexRC = "vec2(yRC) * effectiveInputOverOutputRatioRC";
        }
        this.userCode = "\n const vec2 effectiveInputOverOutputRatioRC = vec2(\n " + heightRatio + ",\n " + widthRatio + ");\n const vec2 inputShapeRC = vec2(" + oldHeight + ".0, " + oldWidth + ".0);\n\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int d = coords[3];\n ivec2 yRC = coords.yz;\n\n // Fractional source index.\n vec2 sourceFracIndexRC = " + sourceFracIndexRC + ";\n\n // Compute the coordinators of nearest neighbor point.\n ivec2 sourceNearestRC = ivec2(\n min(inputShapeRC - 1.0, floor(sourceFracIndexRC + " + roundBase + ")));\n float newValue = getA(b, sourceNearestRC.x, sourceNearestRC.y, d);\n\n setOutput(newValue);\n }\n ";
    }
    return ResizeNearestNeighborProgram;
}());
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Packed shader program for nearest-neighbor image resizing.
 *
 * Each invocation fills one packed vec4. `yRC.z` holds the coordinate
 * `coords.z + 1`, so the nearest source index is computed for the current
 * and the next position along that axis; the `hasNextCol` / `hasNextRow`
 * guards zero out components that would fall outside the output.
 */
var ResizeNearestNeighborPackedProgram = /** @class */ (function () {
    /**
     * @param inputShape `[batch, oldHeight, oldWidth, depth]`.
     * @param newHeight Output height.
     * @param newWidth Output width.
     * @param alignCorners When truthy (and the new size along an axis is > 1),
     *     the input/output ratio for that axis is computed from `size - 1`;
     *     it also switches the index conversion from floor to round-half-up.
     * @param halfPixelCenters When truthy, output coordinates are shifted by
     *     0.5 before scaling and the result is clamped to be non-negative.
     */
    function ResizeNearestNeighborPackedProgram(inputShape, newHeight, newWidth, alignCorners, halfPixelCenters) {
        this.variableNames = ['A'];
        this.packedInputs = true;
        this.packedOutput = true;
        var batch = inputShape[0], oldHeight = inputShape[1], oldWidth = inputShape[2], depth = inputShape[3];
        this.outputShape = [batch, newHeight, newWidth, depth];
        var effectiveInSize = [
            (alignCorners && newHeight > 1) ? oldHeight - 1 : oldHeight,
            (alignCorners && newWidth > 1) ? oldWidth - 1 : oldWidth
        ];
        var effectiveOutSize = [
            (alignCorners && newHeight > 1) ? newHeight - 1 : newHeight,
            (alignCorners && newWidth > 1) ? newWidth - 1 : newWidth
        ];
        var heightRatio = effectiveInSize[0] / effectiveOutSize[0];
        var widthRatio = effectiveInSize[1] / effectiveOutSize[1];
        // When alignCorners is false, the value is rounded down with floor;
        // otherwise 0.5 is added first so that floor rounds to nearest.
        var roundBase = alignCorners ? '0.5' : '0.0';
        var sourceFracIndexRC;
        if (halfPixelCenters) {
            sourceFracIndexRC = "max((vec3(yRC) + vec3(0.5)) * " +
                "effectiveInputOverOutputRatioRC, vec3(0.0))";
        }
        else {
            sourceFracIndexRC = "vec3(yRC) * effectiveInputOverOutputRatioRC";
        }
        // The width ratio and width appear twice because the third vector
        // component tracks the neighbouring position along the width axis.
        this.userCode = "\n const vec3 effectiveInputOverOutputRatioRC = vec3(\n " + heightRatio + ",\n " + widthRatio + ",\n " + widthRatio + ");\n const vec3 inputShapeRC = vec3(" + oldHeight + ".0, " + oldWidth + ".0,\n " + oldWidth + ".0);\n\n float getAValue(int b, int r, int c, int d) {\n return getChannel(getA(b, r, c, d), vec2(c, d));\n }\n\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int d = coords[3];\n // Calculate values for next column in yRC.z.\n ivec3 yRC = coords.yzz + ivec3(0, 0, 1);\n\n // Fractional source index.\n vec3 sourceFracIndexRC = " + sourceFracIndexRC + ";\n\n // Compute the coordinators of nearest neighbor point.\n ivec3 sourceNearestRC = ivec3(\n min(inputShapeRC - 1.0, floor(sourceFracIndexRC + " + roundBase + ")));\n\n // Should we calculate next column and row elements in 2x2 packed cell.\n bool hasNextCol = d < " + (depth - 1) + ";\n bool hasNextRow = coords.z < " + (newWidth - 1) + ";\n\n vec4 newValue = vec4(\n getAValue(b, sourceNearestRC.x, sourceNearestRC.y, d),\n hasNextCol ? getAValue(b, sourceNearestRC.x, sourceNearestRC.y, d + 1)\n : 0.0,\n hasNextRow ? getAValue(b, sourceNearestRC.x, sourceNearestRC.z, d)\n : 0.0,\n (hasNextRow && hasNextCol) ?\n getAValue(b, sourceNearestRC.x, sourceNearestRC.z, d + 1) : 0.0);\n\n setOutput(newValue);\n }\n ";
    }
    return ResizeNearestNeighborPackedProgram;
}());
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * WebGL kernel function for nearest-neighbor resize.
 *
 * Reads `inputs.images` and the attrs `alignCorners`, `halfPixelCenters` and
 * `size` ([newHeight, newWidth]), builds the packed program when the
 * WEBGL_PACK_IMAGE_OPERATIONS flag is on (the unpacked one otherwise) and
 * runs it on the backend with the dtype of `images`.
 *
 * @param args Object with `inputs`, `backend` and `attrs`.
 * @returns Whatever `backend.runWebGLProgram` returns.
 */
function resizeNearestNeighbor(args) {
    var backend = args.backend;
    var attrs = args.attrs;
    var images = args.inputs.images;
    var targetHeight = attrs.size[0];
    var targetWidth = attrs.size[1];
    var program;
    if (tf.env().getBool('WEBGL_PACK_IMAGE_OPERATIONS')) {
        program = new ResizeNearestNeighborPackedProgram(images.shape, targetHeight, targetWidth, attrs.alignCorners, attrs.halfPixelCenters);
    }
    else {
        program = new ResizeNearestNeighborProgram(images.shape, targetHeight, targetWidth, attrs.alignCorners, attrs.halfPixelCenters);
    }
    return backend.runWebGLProgram(program, [images], images.dtype);
}
|
|
// Kernel descriptor: pairs the ResizeNearestNeighbor kernel name with the
// 'webgl' backend implementation defined above.
var resizeNearestNeighborConfig = {
    kernelName: tf.ResizeNearestNeighbor,
    backendName: 'webgl',
    kernelFunc: resizeNearestNeighbor
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
var ResizeNearestNeigborBackpropProgram = /** @class */ (function () {
    /**
     * Builds the GLSL program for the gradient of nearest-neighbor resize.
     *
     * For every output (dx) position the shader scans a window of `dy` and
     * sums the `dy` entries whose nearest source pixel is that position.
     *
     * @param dyShape Shape of `dy`; entries 1 and 2 are read as height/width.
     * @param inputShape Shape of the forward-pass input; it is used directly
     *     as the output shape, and entries 1 and 2 are read as height/width.
     * @param alignCorners Align-corners flag; it selects the (size - 1)
     *     ratios and is interpolated into the shader as the condition that
     *     picks round() over floor().
     */
    function ResizeNearestNeigborBackpropProgram(dyShape, inputShape, alignCorners) {
        this.variableNames = ['dy'];
        this.outputShape = inputShape;
        var xRows = inputShape[1];
        var xCols = inputShape[2];
        var dyRows = dyShape[1];
        var dyCols = dyShape[2];
        // In the backwards pass we look for the dy pixels that the forward
        // pass generated from each input pixel and add them to the gradient.
        var alignRows = alignCorners && dyRows > 1;
        var alignCols = alignCorners && dyCols > 1;
        var effXRows = alignRows ? xRows - 1 : xRows;
        var effXCols = alignCols ? xCols - 1 : xCols;
        var effDyRows = alignRows ? dyRows - 1 : dyRows;
        var effDyCols = alignCols ? dyCols - 1 : dyCols;
        var rowScale = effXRows / effDyRows;
        var colScale = effXCols / effDyCols;
        var invRowScale = 1 / rowScale;
        var invColScale = 1 / colScale;
        // Size of the window of dy values searched around each dx index.
        var windowRows = (Math.ceil(invRowScale) * 2) + 2;
        var windowCols = (Math.ceil(invColScale) * 2) + 2;
        this.userCode = "\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int d = coords[3];\n int r = coords[1];\n int c = coords[2];\n\n float accumulator = 0.0;\n\n const float heightScale = float(" + rowScale +
            ");\n const float widthScale = float(" + colScale +
            ");\n\n const float invHeightScale = float(" + invRowScale +
            ");\n const float invWidthScale = float(" + invColScale +
            ");\n\n const int winHeight = int(" + windowRows +
            ");\n const int winWidth = int(" + windowCols +
            ");\n\n // Compute bounds for where in dy we will look\n float startRLerp = floor(float(r) * invHeightScale);\n int startDyR = int(floor(startRLerp - float(winHeight / 2)));\n\n float startCLerp = floor(float(c) * invWidthScale);\n int startDyC = int(floor(startCLerp - float(winWidth / 2)));\n\n // Loop over dy\n for (int dyROffset = 0; dyROffset < winHeight; dyROffset++) {\n int dyR = dyROffset + startDyR;\n\n // Guard against the window exceeding the bounds of dy\n if (dyR < 0 || dyR >= " + dyRows +
            ") {\n continue;\n }\n\n for (int dyCOffset = 0; dyCOffset < winWidth; dyCOffset++) {\n int dyC = dyCOffset + startDyC;\n\n // Guard against the window exceeding the bounds of dy\n if (dyC < 0 || dyC >= " + dyCols +
            ") {\n continue;\n }\n\n float sourceFracRow =\n float(" + effXRows +
            ") *\n (float(dyR) / float(" + effDyRows +
            "));\n\n float sourceFracCol =\n float(" + effXCols +
            ") *\n (float(dyC) / float(" + effDyCols +
            "));\n\n int sourceNearestRow = int(min(\n float(int(" + xRows +
            ") - 1),\n " + alignCorners +
            " ? float(round(sourceFracRow)) :\n float(floor(sourceFracRow))));\n\n int sourceNearestCol = int(min(\n float(int(" + xCols +
            ") - 1),\n " + alignCorners +
            " ? float(round(sourceFracCol)) :\n float(floor(sourceFracCol))));\n\n if (r == sourceNearestRow && c == sourceNearestCol) {\n accumulator += getDy(b, dyR, dyC, d);\n }\n }\n }\n // End loop over dy\n\n setOutput(accumulator);\n }\n ";
    }
    return ResizeNearestNeigborBackpropProgram;
}());
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * WebGL kernel function for the gradient of nearest-neighbor resize.
 *
 * Builds the backprop program from the shapes of `inputs.dy` and
 * `inputs.images` plus `attrs.alignCorners`, then runs it with `dy` as the
 * only shader input and `dy.dtype` as the output dtype.
 *
 * @param args Object with `inputs`, `backend` and `attrs`.
 * @returns Whatever `backend.runWebGLProgram` returns.
 */
function resizeNearestNeighborGrad(args) {
    var backend = args.backend;
    var forwardInput = args.inputs.images;
    var upstreamGrad = args.inputs.dy;
    var gradProgram = new ResizeNearestNeigborBackpropProgram(upstreamGrad.shape, forwardInput.shape, args.attrs.alignCorners);
    return backend.runWebGLProgram(gradProgram, [upstreamGrad], upstreamGrad.dtype);
}
|
|
// Kernel descriptor: pairs the ResizeNearestNeighborGrad kernel name with the
// 'webgl' backend implementation defined above.
var resizeNearestNeighborGradConfig = {
    kernelName: tf.ResizeNearestNeighborGrad,
    backendName: 'webgl',
    kernelFunc: resizeNearestNeighborGrad
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2017 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
var ReverseProgram = /** @class */ (function () {
    /**
     * Builds the unpacked GLSL program that reverses input `x` along the
     * given axes.
     *
     * @param xShape Shape of the input; also used as the output shape.
     * @param axis Array of axis indices to flip. Axes of size 1 are left
     *     untouched. Not consulted for rank-1 input, which is always flipped.
     * @throws Error when the input rank is greater than 4.
     */
    function ReverseProgram(xShape, axis) {
        this.variableNames = ['x'];
        var rank = xShape.length;
        if (rank > 4) {
            throw new Error("WebGL backend: Reverse of rank-" + rank + " tensor is not yet supported");
        }
        this.outputShape = xShape;
        if (rank === 1) {
            this.userCode = "\n void main() {\n int coord = getOutputCoords();\n setOutput(getX(" + xShape[0] + " - coord - 1));\n }\n ";
            return;
        }
        // One source-coordinate expression per dimension: mirrored when the
        // dimension is reversed and longer than one element, else unchanged.
        var sourceCoords = [];
        for (var dim = 0; dim < rank; dim++) {
            var mirrored = axis.indexOf(dim) !== -1 && xShape[dim] !== 1;
            sourceCoords.push(mirrored ?
                xShape[dim] + " - coords[" + dim + "] - 1" :
                "coords[" + dim + "]");
        }
        var coordsType = getCoordsDataType(rank);
        this.userCode = "\n void main() {\n " + coordsType + " coords = getOutputCoords();\n setOutput(getX(" + sourceCoords.join(',') + "));\n }\n ";
    }
    return ReverseProgram;
}());
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
var ReversePackedProgram = /** @class */ (function () {
    /**
     * Builds the packed GLSL program that reverses input `x` along the given
     * axes. Each invocation fills up to four result channels (r, g, b, a),
     * guarding the neighbouring column/row against the output bounds.
     *
     * @param xShape Shape of the input; also used as the output shape.
     * @param axis Array of axis indices to flip. Axes of size 1 are left
     *     untouched. Not consulted for rank-1 input, which is always flipped.
     * @throws Error when the input rank is greater than 4.
     */
    function ReversePackedProgram(xShape, axis) {
        this.variableNames = ['x'];
        this.packedInputs = true;
        this.packedOutput = true;
        var rank = xShape.length;
        if (rank > 4) {
            throw new Error("WebGL backend: Reverse of rank-" + rank + " tensor is not yet supported");
        }
        this.outputShape = xShape;
        var channels = getChannels('rc', rank);
        var hasNextColumn = channels[rank - 1] + " + 1 < " + this.outputShape[rank - 1];
        var hasNextRow = channels[rank - 2] + " + 1 < " + this.outputShape[rank - 2];
        var coordsType = getCoordsDataType(rank);
        /**
         * Returns the GLSL expression that reads the reversed source value
         * for the current output coordinate, optionally shifted by one along
         * the innermost (column) and/or second-innermost (row) dimension.
         */
        function readSource(shiftColumn, shiftRow) {
            var coordNames = channels.slice();
            if (shiftColumn) {
                coordNames[rank - 1] = '(' + coordNames[rank - 1] + " + 1)";
            }
            if (shiftRow) {
                coordNames[rank - 2] = '(' + coordNames[rank - 2] + " + 1)";
            }
            var sourceCoords = [];
            for (var dim = 0; dim < xShape.length; dim++) {
                var mirrored = axis.indexOf(dim) !== -1 && xShape[dim] !== 1;
                sourceCoords.push(mirrored ?
                    xShape[dim] + " - " + coordNames[dim] + " - 1" :
                    "" + coordNames[dim]);
            }
            var innerDims = sourceCoords.slice(-2).join(',');
            return "getChannel(getX(" + sourceCoords.join(',') + "), vec2(" + innerDims + "))";
        }
        if (rank === 1) {
            this.userCode = "\n void main(){\n int rc = getOutputCoords();\n vec4 result = vec4(0.);\n result.r = getChannel(getX(" + xShape[0] + " - rc - 1),\n " + xShape[0] +
                " - rc - 1);\n if(" + hasNextColumn +
                "){\n result.g = getChannel(getX(" + xShape[0] + " - (rc + 1) - 1),\n " + xShape[0] +
                " - (rc + 1) - 1);\n }\n setOutput(result);\n }\n ";
            return;
        }
        this.userCode = "\n void main() {\n " + coordsType +
            " rc = getOutputCoords();\n vec4 result = vec4(0.);\n result.r = " + readSource(false, false) +
            ";\n if(" + hasNextColumn +
            "){\n result.g = " + readSource(true, false) +
            ";\n }\n if(" + hasNextRow +
            ") {\n result.b = " + readSource(false, true) +
            ";\n if(" + hasNextColumn +
            ") {\n result.a = " + readSource(true, true) +
            ";\n }\n }\n setOutput(result);\n }\n ";
    }
    return ReversePackedProgram;
}());
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * WebGL kernel function for Reverse.
 *
 * Normalizes `attrs.dims` with `tf.util.parseAxisParam`, hands rank-0 input
 * to `identity`, and otherwise runs the packed or unpacked reverse program
 * depending on the WEBGL_PACK_ARRAY_OPERATIONS flag.
 *
 * @param args Object with `inputs` ({x}), `backend` and `attrs` ({dims}).
 * @returns The result of `identity` for rank-0 input, otherwise whatever
 *     `backend.runWebGLProgram` returns.
 */
function reverse(args) {
    var backend = args.backend;
    var x = args.inputs.x;
    // Axes are parsed before the rank-0 shortcut, as in the original flow.
    var axes = tf.util.parseAxisParam(args.attrs.dims, x.shape);
    if (x.shape.length === 0) {
        return identity({ inputs: { x: x }, backend: backend });
    }
    var program;
    if (tf.env().getBool('WEBGL_PACK_ARRAY_OPERATIONS')) {
        program = new ReversePackedProgram(x.shape, axes);
    }
    else {
        program = new ReverseProgram(x.shape, axes);
    }
    return backend.runWebGLProgram(program, [x], x.dtype);
}
|
|
// Kernel descriptor: pairs the Reverse kernel name with the 'webgl' backend
// implementation defined above.
var reverseConfig = {
    kernelName: tf.Reverse,
    backendName: 'webgl',
    kernelFunc: reverse
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
var RotateProgram = /** @class */ (function () {
    /**
     * Builds the GLSL program that rotates input `Image` around a centre
     * supplied at run time through the `params` uniform
     * (centerX, centerY, sinFactor, cosFactor).
     *
     * Output pixels whose rotated source coordinate falls outside the image
     * receive the fill value.
     *
     * @param imageShape Shape of the image; entries 1 and 2 are read as
     *     height and width. It is also used as the output shape.
     * @param fillValue Either a single number used for every channel, or an
     *     array whose comma-joined entries initialise a `vec3` indexed by the
     *     channel coordinate.
     */
    function RotateProgram(imageShape, fillValue) {
        this.variableNames = ['Image'];
        this.outputShape = [];
        var imageHeight = imageShape[1];
        var imageWidth = imageShape[2];
        this.outputShape = imageShape;
        var fillSnippet = '';
        if (typeof fillValue === 'number') {
            // Emit the value at full precision. The previous `toFixed(2)`
            // rounded it to two decimals (e.g. 0.125 -> 0.13), unlike the
            // array branch below. Wrapping in float(...) keeps integer-looking
            // values such as `0` or `255` valid as a float initialiser.
            fillSnippet = "float outputValue = float(" + fillValue + ");";
        }
        else {
            fillSnippet = "\n vec3 fill = vec3(" + fillValue.join(',') + ");\n float outputValue = fill[coords[3]];";
        }
        this.userCode = "\n uniform vec4 params;\n void main() {\n ivec4 coords = getOutputCoords();\n int x = coords[2];\n int y = coords[1];\n float coordXFloat = (float(x) - params[0]) * params[3] -\n (float(y) - params[1]) * params[2];\n float coordYFloat = (float(x) - params[0]) * params[2] +\n (float(y) - params[1]) * params[3];\n int coordX = int(round(coordXFloat + params[0]));\n int coordY = int(round(coordYFloat + params[1]));\n " + fillSnippet + "\n if(coordX >= 0 && coordX < " + imageWidth + " && coordY >= 0 && coordY < " + imageHeight + ") {\n outputValue = getImage(coords[0], coordY, coordX, coords[3]);\n }\n setOutput(outputValue);\n }\n ";
    }
    /**
     * Returns a setup callback that uploads the rotation parameters to the
     * `params` uniform. The uniform location is looked up on first use and
     * cached on this program instance.
     *
     * @param centerX X coordinate of the rotation centre.
     * @param centerY Y coordinate of the rotation centre.
     * @param sinFactor Sine of the rotation angle.
     * @param cosFactor Cosine of the rotation angle.
     * @returns Function taking (gpgpu, webGLProgram).
     */
    RotateProgram.prototype.getCustomSetupFunc = function (centerX, centerY, sinFactor, cosFactor) {
        var _this = this;
        return function (gpgpu, webGLProgram) {
            if (_this.paramsLoc == null) {
                _this.paramsLoc =
                    gpgpu.getUniformLocationNoThrow(webGLProgram, 'params');
            }
            gpgpu.gl.uniform4f(_this.paramsLoc, centerX, centerY, sinFactor, cosFactor);
        };
    };
    return RotateProgram;
}());
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// Kernel descriptor for RotateWithOffset on the 'webgl' backend. The kernel
// function builds a RotateProgram for `inputs.image`, resolves the rotation
// centre with `tf.backend_util.getImageCenter`, and runs the program with a
// custom setup callback carrying the centre and the sine/cosine of
// `attrs.radians`.
var rotateWithOffsetConfig = {
    kernelName: tf.RotateWithOffset,
    backendName: 'webgl',
    kernelFunc: function (args) {
        var image = args.inputs.image;
        var attrs = args.attrs;
        var webglBackend = args.backend;
        var angle = attrs.radians;
        var rotateProgram = new RotateProgram(image.shape, attrs.fillValue);
        var imageCenter = tf.backend_util.getImageCenter(attrs.center, image.shape[1], image.shape[2]);
        var uploadParams = rotateProgram.getCustomSetupFunc(imageCenter[0], imageCenter[1], Math.sin(angle), Math.cos(angle));
        return webglBackend.runWebGLProgram(rotateProgram, [image], image.dtype, uploadParams);
    }
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// GLSL body for element-wise round: values strictly below/above the .5 mark
// go to floor/ceil, and exact halves go to the even neighbour.
var ROUND = "\n" +
    " // OpenGL ES does not support round function.\n" +
    " // The algorithm is based on banker's rounding.\n" +
    " float base = floor(x);\n" +
    " if ((x - base) < 0.5) {\n" +
    " return floor(x);\n" +
    " } else if ((x - base) > 0.5) {\n" +
    " return ceil(x);\n" +
    " } else {\n" +
    " if (mod(base, 2.0) == 0.0) {\n" +
    " return base;\n" +
    " } else {\n" +
    " return base + 1.0;\n" +
    " }\n" +
    " }\n";
var round = unaryKernelFunc({
    opSnippet: ROUND
});
// Kernel descriptor for Round on the 'webgl' backend.
var roundConfig = {
    kernelName: tf.Round,
    backendName: 'webgl',
    kernelFunc: round
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// GLSL body for element-wise reciprocal square root.
var RSQRT = 'return inversesqrt(x);';
// Also passes the CPU implementation to unaryKernelFunc as `cpuKernelImpl`.
var rsqrt = unaryKernelFunc({
    opSnippet: RSQRT,
    cpuKernelImpl: rsqrtImplCPU
});
// Kernel descriptor for Rsqrt on the 'webgl' backend.
var rsqrtConfig = {
    kernelName: tf.Rsqrt,
    backendName: 'webgl',
    kernelFunc: rsqrt
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
var ScatterProgram = /** @class */ (function () {
    /**
     * Builds the GLSL program for scatter. For each output coordinate the
     * shader loops over all updates, flattens each update's index with the
     * given strides and sums the updates whose flattened index equals
     * `coords[0]`; when none matches, the default value is written instead.
     *
     * @param updateSize Number of updates iterated by the outer loop.
     * @param sliceDim Number of index components per update (inner loop
     *     bound); strides are indexed per component only when this is > 1.
     * @param indicesRank Rank of the indices sampler (1 or 2 select the
     *     `i` / `i, j` argument lists; anything else yields an empty list).
     * @param updatesRank Rank of the updates sampler (1 or 2 select the
     *     `i` / `i, coords[1]` argument lists; anything else yields none).
     * @param strides Strides interpolated into the shader.
     * @param shape Output shape.
     * @param summingDupeIndex Not read by this constructor.
     */
    function ScatterProgram(updateSize, sliceDim, indicesRank, updatesRank, strides, shape, summingDupeIndex) {
        this.variableNames = ['updates', 'indices', 'defaultValue'];
        this.outputShape = shape;
        var stridesType = getCoordsDataType(strides.length);
        var coordsType = getCoordsDataType(shape.length);
        var indicesArgs = indicesRank === 1 ? 'i' :
            (indicesRank === 2 ? 'i, j' : '');
        var updatesArgs = updatesRank === 1 ? 'i' :
            (updatesRank === 2 ? 'i, coords[1]' : '');
        var readIndex = "getIndices(" + indicesArgs + ")";
        var readUpdate = "getUpdates(" + updatesArgs + ")";
        var strideAccess = sliceDim > 1 ? 'strides[j]' : 'strides';
        this.userCode = "\n " + stridesType + " strides = " + stridesType + "(" + strides +
            ");\n\n void main() {\n " + coordsType +
            " coords = getOutputCoords();\n float sum = 0.0;\n bool found = false;\n for (int i = 0; i < " + updateSize +
            "; i++) {\n int flattenedIndex = 0;\n for (int j = 0; j < " + sliceDim +
            "; j++) {\n int index = round(" + readIndex +
            ");\n flattenedIndex += index * " + strideAccess +
            ";\n }\n if (flattenedIndex == coords[0]) {\n sum += " + readUpdate +
            ";\n found = true;\n }\n }\n setOutput(mix(getDefaultValue(), sum, float(found)));\n }\n ";
    }
    return ScatterProgram;
}());
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * WebGL kernel function for ScatterNd.
 *
 * Flattens `indices` to [numUpdates, sliceRank] and `updates` to
 * [numUpdates, sliceSize], runs ScatterProgram with a scalar 0 as the
 * default value, and reshapes the result to `attrs.shape`. All intermediate
 * tensors are disposed before returning.
 *
 * @param args Object with `inputs` ({indices, updates}), `backend` and
 *     `attrs` ({shape}).
 * @returns Tensor info of shape `attrs.shape` with the dtype of `updates`.
 */
function scatterNd(args) {
    var inputs = args.inputs, backend = args.backend, attrs = args.attrs;
    var indices = inputs.indices, updates = inputs.updates;
    var shape = attrs.shape;
    var _a = tf.backend_util.calculateShapes(updates, indices, shape), sliceRank = _a.sliceRank, numUpdates = _a.numUpdates, sliceSize = _a.sliceSize, strides = _a.strides, outputSize = _a.outputSize;
    if (outputSize === 0) {
        // An empty result must carry the dtype of `updates`, like the
        // non-empty path below; it previously used the dtype of `indices`.
        return backend.makeTensorInfo(shape, updates.dtype);
    }
    var flattenShape = [outputSize / sliceSize, sliceSize];
    var flattenIndices = reshape({ inputs: { x: indices }, backend: backend, attrs: { shape: [numUpdates, sliceRank] } });
    var flattenX = reshape({ inputs: { x: updates }, backend: backend, attrs: { shape: [numUpdates, sliceSize] } });
    // Scalar 0 written to output positions that no update targets.
    var defaultValue = backend.makeTensorInfo([], 'float32', new Float32Array([0]));
    var program = new ScatterProgram(numUpdates, sliceRank, flattenIndices.shape.length, flattenX.shape.length, strides, flattenShape);
    var res = backend.runWebGLProgram(program, [flattenX, flattenIndices, defaultValue], flattenX.dtype);
    var reshaped = reshape({ inputs: { x: res }, backend: backend, attrs: { shape: shape } });
    backend.disposeIntermediateTensorInfo(flattenIndices);
    backend.disposeIntermediateTensorInfo(flattenX);
    backend.disposeIntermediateTensorInfo(res);
    backend.disposeIntermediateTensorInfo(defaultValue);
    return reshaped;
}
|
|
// Kernel descriptor: pairs the ScatterNd kernel name with the 'webgl' backend
// implementation defined above.
var scatterNdConfig = {
    kernelName: tf.ScatterNd,
    backendName: 'webgl',
    kernelFunc: scatterNd
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2017 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
var SelectProgram = /** @class */ (function () {
    /**
     * Builds the GLSL program for element-wise select: the output takes the
     * value of `a` where the condition `c` is >= 1.0 and of `b` otherwise.
     *
     * @param cRank Rank of the condition; for rank > 1 outputs only the first
     *     `cRank` output coordinates are used to sample it.
     * @param shape Output shape.
     * @param rank Output rank used to pick the coordinate type.
     * @throws Error when `rank` is greater than 4.
     */
    function SelectProgram(cRank, shape, rank) {
        this.variableNames = ['c', 'a', 'b'];
        this.outputShape = shape;
        if (rank > 4) {
            throw Error("Where for rank " + rank + " is not yet supported");
        }
        var conditionCoords = "resRC";
        var valueCoords = "resRC";
        if (rank !== 1) {
            var components = ['resRC.x', 'resRC.y', 'resRC.z', 'resRC.w'];
            var allCoords = [];
            for (var dim = 0; dim < shape.length; dim++) {
                allCoords.push("" + components[dim]);
            }
            // The condition is sampled with the leading cRank coordinates.
            conditionCoords = allCoords
                .filter(function (_, dim) { return dim < cRank; })
                .join();
            valueCoords = allCoords.join();
        }
        var coordsType = getCoordsDataType(rank);
        this.userCode = "\n void main() {\n " + coordsType +
            " resRC = getOutputCoords();\n float cVal = getC(" + conditionCoords +
            ");\n if (cVal >= 1.0) {\n setOutput(getA(" + valueCoords +
            "));\n } else {\n setOutput(getB(" + valueCoords +
            "));\n }\n }\n ";
    }
    return SelectProgram;
}());
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * WebGL kernel function for Select.
 *
 * Builds a SelectProgram from the rank of `inputs.condition` and the
 * shape/rank of `inputs.t`, then runs it on [condition, t, e] with the
 * upcast of the dtypes of `t` and `e` as output dtype.
 *
 * @param args Object with `inputs` ({condition, t, e}) and `backend`.
 * @returns Whatever `backend.runWebGLProgram` returns.
 */
function select(args) {
    var backend = args.backend;
    var condition = args.inputs.condition;
    var whenTrue = args.inputs.t;
    var whenFalse = args.inputs.e;
    var selectProgram = new SelectProgram(condition.shape.length, whenTrue.shape, whenTrue.shape.length);
    var outputDtype = tf.upcastType(whenTrue.dtype, whenFalse.dtype);
    return backend.runWebGLProgram(selectProgram, [condition, whenTrue, whenFalse], outputDtype);
}
|
|
// Kernel descriptor: pairs the Select kernel name with the 'webgl' backend
// implementation defined above.
var selectConfig = {
    kernelName: tf.Select,
    backendName: 'webgl',
    kernelFunc: select
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// GLSL body for element-wise SELU; the two constants are taken from
// tf.backend_util and interpolated into the snippet.
var SELU = [
    "",
    " // Stable and Attracting Fixed Point (0, 1) for Normalized Weights.",
    " // see: https://arxiv.org/abs/1706.02515",
    " float scaleAlpha = " + tf.backend_util.SELU_SCALEALPHA + ";",
    " float scale = " + tf.backend_util.SELU_SCALE + ";",
    " return (x >= 0.0) ? scale * x : scaleAlpha * (exp(x) - 1.0);",
    ""
].join("\n");
var selu = unaryKernelFunc({
    opSnippet: SELU
});
// Kernel descriptor for Selu on the 'webgl' backend.
var seluConfig = {
    kernelName: tf.Selu,
    backendName: 'webgl',
    kernelFunc: selu
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// GLSL body for the element-wise logistic sigmoid.
var SIGMOID$2 = 'return 1.0 / (1.0 + exp(-1.0 * x));';
var sigmoid = unaryKernelFunc({
    opSnippet: SIGMOID$2
});
// Kernel descriptor for Sigmoid on the 'webgl' backend.
var sigmoidConfig = {
    kernelName: tf.Sigmoid,
    backendName: 'webgl',
    kernelFunc: sigmoid
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// GLSL body for element-wise sign. NaN is not propagated: a NaN input
// yields 0.0.
var SIGN = "\n" +
    " if (isnan(x)) { return 0.0; }\n" +
    " return sign(x);\n";
var sign = unaryKernelFunc({
    opSnippet: SIGN
});
// Kernel descriptor for Sign on the 'webgl' backend.
var signConfig = {
    kernelName: tf.Sign,
    backendName: 'webgl',
    kernelFunc: sign
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// GLSL body for sin(x), prefixed with the shared CHECK_NAN_SNIPPET_UNARY
// snippet (declared elsewhere in this file).
var SIN = CHECK_NAN_SNIPPET_UNARY + "\n return sin(x);\n";
// Sin kernel function built from the snippet above via unaryKernelFunc.
var sin = unaryKernelFunc({
    opSnippet: SIN
});
// Kernel config tying the tf.Sin kernel name to the 'webgl' backend.
var sinConfig = {
    kernelName: tf.Sin,
    backendName: 'webgl',
    kernelFunc: sin
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// GLSL body for sinh(x) = (e^x - e^-x) / 2. Note that the shader-local
// variable is called `e2x` but holds exp(x); e^-x is obtained as 1.0 / e2x.
var SINH = "\n float e2x = exp(x);" +
    "\n return (e2x - 1.0 / e2x) / 2.0;\n";
// Sinh kernel function built from the snippet above via unaryKernelFunc.
var sinh = unaryKernelFunc({
    opSnippet: SINH
});
// Kernel config tying the tf.Sinh kernel name to the 'webgl' backend.
var sinhConfig = {
    kernelName: tf.Sinh,
    backendName: 'webgl',
    kernelFunc: sinh
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// GLSL body for softplus(x) = log(exp(x) + 1) with two shortcuts:
//   - threshold = log(epsilon) + 2.0, where epsilon = 1.1920928955078125e-7;
//   - x > -threshold  -> returns x itself;
//   - x <  threshold  -> returns exp(x);
//   - otherwise       -> returns log(exp(x) + 1.0).
// The string is assembled line by line; the concatenated value is unchanged.
var SOFTPLUS = "\n float epsilon = 1.1920928955078125e-7;" +
    "\n float threshold = log(epsilon) + 2.0;" +
    "\n" +
    "\n bool too_large = x > -threshold;" +
    "\n bool too_small = x < threshold;" +
    "\n" +
    "\n float result;" +
    "\n float exp_x = exp(x);" +
    "\n" +
    "\n if (too_large){" +
    "\n result = x;" +
    "\n }" +
    "\n else if (too_small){" +
    "\n result = exp_x;" +
    "\n }" +
    "\n else{" +
    "\n result = log(exp_x + 1.0);" +
    "\n }" +
    "\n return result;\n";
// Softplus kernel function built from the snippet above via unaryKernelFunc.
var softplus = unaryKernelFunc({
    opSnippet: SOFTPLUS
});
// Kernel config tying the tf.Softplus kernel name to the 'webgl' backend.
var softplusConfig = {
    kernelName: tf.Softplus,
    backendName: 'webgl',
    kernelFunc: softplus
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * SpaceToBatchND kernel for the 'webgl' backend.
 *
 * Pipeline: zero-pad `x` (padV2) -> reshape -> transpose -> reshape. The
 * intermediate shapes and the permutation come from tf.backend_util
 * (getReshaped / getPermuted / getReshapedPermuted). The three intermediate
 * tensors are disposed before returning; only the final reshape is returned.
 *
 * @param args {inputs: {x}, backend, attrs: {blockShape, paddings}}
 * @returns tensor info of the final reshape.
 * Asserts (via tf.util.assert) that x has rank <= 4.
 */
var spaceToBatchND = function (args) {
    var backend = args.backend;
    var x = args.inputs.x;
    var blockShape = args.attrs.blockShape;
    var paddings = args.attrs.paddings;
    tf.util.assert(x.shape.length <= 4, function () {
        return 'spaceToBatchND for rank > 4 with a WebGL backend not implemented yet';
    });
    var prod = blockShape.reduce(function (acc, dim) { return acc * dim; });
    // First dimension is never padded; then the caller's paddings; any
    // remaining trailing dimensions get [0, 0].
    var completePaddings = [[0, 0]];
    for (var p = 0; p < paddings.length; p++) {
        completePaddings.push(paddings[p]);
    }
    for (var d = 1 + blockShape.length; d < x.shape.length; ++d) {
        completePaddings.push([0, 0]);
    }
    var paddedX = padV2({
        inputs: { x: x },
        backend: backend,
        attrs: { paddings: completePaddings, constantValue: 0 }
    });
    var reshapedPaddedShape = tf.backend_util.getReshaped(paddedX.shape, blockShape, prod, false);
    var perm = tf.backend_util.getPermuted(reshapedPaddedShape.length, blockShape.length, false);
    var flattenShape = tf.backend_util.getReshapedPermuted(paddedX.shape, blockShape, prod, false);
    var reshapedPaddedX = reshape({
        inputs: { x: paddedX },
        backend: backend,
        attrs: { shape: reshapedPaddedShape }
    });
    var paddedXT = transpose({
        inputs: { x: reshapedPaddedX },
        backend: backend,
        attrs: { perm: perm }
    });
    var result = reshape({
        inputs: { x: paddedXT },
        backend: backend,
        attrs: { shape: flattenShape }
    });
    // Release intermediates in creation order.
    backend.disposeIntermediateTensorInfo(paddedX);
    backend.disposeIntermediateTensorInfo(reshapedPaddedX);
    backend.disposeIntermediateTensorInfo(paddedXT);
    return result;
};
// Kernel config tying the tf.SpaceToBatchND kernel name to the 'webgl' backend.
var spaceToBatchNDConfig = {
    kernelName: tf.SpaceToBatchND,
    backendName: 'webgl',
    kernelFunc: spaceToBatchND
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2021 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * SparseFillEmptyRows kernel for the 'webgl' backend.
 *
 * Checks the ranks of the four inputs, reads them back synchronously with
 * backend.readSync and delegates the computation to
 * sparseFillEmptyRowsImplCPU.
 *
 * @param args {inputs: {indices, values, denseShape, defaultValue}, backend}
 * @returns array of four tensor infos: output indices, output values,
 *     a 'bool' empty-row indicator and the reverse index map.
 * @throws Error when denseShape is not rank 1, indices is not rank 2,
 *     values is not rank 1 or defaultValue is not rank 0.
 */
function sparseFillEmptyRows(args) {
    var backend = args.backend;
    var indices = args.inputs.indices;
    var values = args.inputs.values;
    var denseShape = args.inputs.denseShape;
    var defaultValue = args.inputs.defaultValue;
    if (denseShape.shape.length !== 1) {
        throw new Error("Dense shape must be a vector, saw:\n " + denseShape.shape);
    }
    if (indices.shape.length !== 2) {
        throw new Error("Indices must be a matrix, saw:\n " + indices.shape);
    }
    if (values.shape.length !== 1) {
        throw new Error("Values must be a vector, saw:\n " + values.shape);
    }
    if (defaultValue.shape.length !== 0) {
        throw new Error("Default value must be a scalar, saw:\n " + defaultValue.shape);
    }
    var indicesData = backend.readSync(indices.dataId);
    var valuesData = backend.readSync(values.dataId);
    var denseShapeData = backend.readSync(denseShape.dataId);
    var defaultScalar = backend.readSync(defaultValue.dataId)[0];
    var filled = sparseFillEmptyRowsImplCPU(indicesData, indices.shape, indices.dtype, valuesData, values.dtype, denseShapeData, defaultScalar);
    var outputIndices = filled[0];
    var outputIndicesShape = filled[1];
    var outputValues = filled[2];
    var emptyRowIndicator = filled[3];
    var reverseIndexMap = filled[4];
    var indicesInfo = backend.makeTensorInfo(outputIndicesShape, indices.dtype, outputIndices);
    var valuesInfo = backend.makeTensorInfo([outputIndicesShape[0]], values.dtype, outputValues);
    // The indicator entries are converted with Number() and stored as bytes.
    var indicatorBytes = new Uint8Array(emptyRowIndicator.map(function (flag) {
        return Number(flag);
    }));
    var indicatorInfo = backend.makeTensorInfo([emptyRowIndicator.length], 'bool', indicatorBytes);
    var reverseMapInfo = backend.makeTensorInfo([reverseIndexMap.length], indices.dtype, new Int32Array(reverseIndexMap));
    return [indicesInfo, valuesInfo, indicatorInfo, reverseMapInfo];
}
// Kernel config tying the tf.SparseFillEmptyRows kernel name to the 'webgl' backend.
var sparseFillEmptyRowsConfig = {
    kernelName: tf.SparseFillEmptyRows,
    backendName: 'webgl',
    kernelFunc: sparseFillEmptyRows
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2021 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * SparseReshape kernel for the 'webgl' backend.
 *
 * Checks input ranks, reads the three inputs back synchronously and delegates
 * to sparseReshapeImplCPU.
 *
 * @param args {inputs: {inputIndices, inputShape, newShape}, backend}
 * @returns array of two tensor infos: the new indices and the output shape
 *     (stored as an Int32Array with newShape's dtype).
 * @throws Error when inputIndices is not rank 2, or inputShape / newShape is
 *     not rank 1.
 */
function sparseReshape(args) {
    var backend = args.backend;
    var inputIndices = args.inputs.inputIndices;
    var inputShape = args.inputs.inputShape;
    var newShape = args.inputs.newShape;
    if (inputIndices.shape.length !== 2) {
        throw new Error("Input indices should be a matrix but received shape " + inputIndices.shape);
    }
    if (inputShape.shape.length !== 1) {
        throw new Error("Input shape should be a vector but received shape " + inputShape.shape);
    }
    if (newShape.shape.length !== 1) {
        throw new Error("Target shape should be a vector but received shape " + newShape.shape);
    }
    // Shapes are copied into plain arrays; indices stay in whatever readSync returns.
    var sourceShape = Array.from(backend.readSync(inputShape.dataId));
    var indicesData = backend.readSync(inputIndices.dataId);
    var targetShape = Array.from(backend.readSync(newShape.dataId));
    var reshaped = sparseReshapeImplCPU(indicesData, inputIndices.shape, inputIndices.dtype, sourceShape, targetShape);
    var newIndices = reshaped[0];
    var indicesShape = reshaped[1];
    var outputShape = reshaped[2];
    var indicesInfo = backend.makeTensorInfo(indicesShape, inputIndices.dtype, newIndices);
    var shapeInfo = backend.makeTensorInfo([outputShape.length], newShape.dtype, new Int32Array(outputShape));
    return [indicesInfo, shapeInfo];
}
// Kernel config tying the tf.SparseReshape kernel name to the 'webgl' backend.
var sparseReshapeConfig = {
    kernelName: tf.SparseReshape,
    backendName: 'webgl',
    kernelFunc: sparseReshape
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2021 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * SparseSegmentMean kernel for the 'webgl' backend.
 *
 * Checks input ranks, reads data / indices / segment ids back synchronously
 * and delegates to sparseSegmentReductionImplCPU with its sixth argument set
 * to true.
 *
 * @param args {inputs: {data, indices, segmentIds}, backend}
 * @returns tensor info holding the reduction output, with data's dtype.
 * @throws Error when data is a scalar, or indices / segmentIds is not rank 1.
 */
function sparseSegmentMean(args) {
    var backend = args.backend;
    var data = args.inputs.data;
    var indices = args.inputs.indices;
    var segmentIds = args.inputs.segmentIds;
    if (data.shape.length < 1) {
        throw new Error("Data should be at least 1 dimensional but received scalar");
    }
    if (indices.shape.length !== 1) {
        throw new Error("Indices should be a vector but received shape\n " + indices.shape);
    }
    if (segmentIds.shape.length !== 1) {
        throw new Error("Segment ids should be a vector but received shape\n " + segmentIds.shape);
    }
    var dataValues = backend.readSync(data.dataId);
    var indexValues = backend.readSync(indices.dataId);
    var segmentIdValues = backend.readSync(segmentIds.dataId);
    var reduced = sparseSegmentReductionImplCPU(dataValues, data.shape, data.dtype, indexValues, segmentIdValues, true);
    // reduced[0] is the output buffer, reduced[1] its shape.
    return backend.makeTensorInfo(reduced[1], data.dtype, reduced[0]);
}
// Kernel config tying the tf.SparseSegmentMean kernel name to the 'webgl' backend.
var sparseSegmentMeanConfig = {
    kernelName: tf.SparseSegmentMean,
    backendName: 'webgl',
    kernelFunc: sparseSegmentMean
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2021 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * SparseSegmentSum kernel for the 'webgl' backend.
 *
 * Checks input ranks, reads data / indices / segment ids back synchronously
 * and delegates to sparseSegmentReductionImplCPU, leaving its optional sixth
 * argument unset.
 *
 * @param args {inputs: {data, indices, segmentIds}, backend}
 * @returns tensor info holding the reduction output, with data's dtype.
 * @throws Error when data is a scalar, or indices / segmentIds is not rank 1.
 */
function sparseSegmentSum(args) {
    var backend = args.backend;
    var data = args.inputs.data;
    var indices = args.inputs.indices;
    var segmentIds = args.inputs.segmentIds;
    if (data.shape.length < 1) {
        throw new Error("Data should be at least 1 dimensional but received scalar");
    }
    if (indices.shape.length !== 1) {
        throw new Error("Indices should be a vector but received shape\n " + indices.shape);
    }
    if (segmentIds.shape.length !== 1) {
        throw new Error("Segment ids should be a vector but received shape\n " + segmentIds.shape);
    }
    var dataValues = backend.readSync(data.dataId);
    var indexValues = backend.readSync(indices.dataId);
    var segmentIdValues = backend.readSync(segmentIds.dataId);
    var reduced = sparseSegmentReductionImplCPU(dataValues, data.shape, data.dtype, indexValues, segmentIdValues);
    // reduced[0] is the output buffer, reduced[1] its shape.
    return backend.makeTensorInfo(reduced[1], data.dtype, reduced[0]);
}
// Kernel config tying the tf.SparseSegmentSum kernel name to the 'webgl' backend.
var sparseSegmentSumConfig = {
    kernelName: tf.SparseSegmentSum,
    backendName: 'webgl',
    kernelFunc: sparseSegmentSum
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * SparseToDense kernel for the 'webgl' backend.
 *
 * Runs a ScatterProgram over (sparseValues, sparseIndices, defaultValue) into
 * a [outputSize, 1] result, then reshapes that result to attrs.outputShape.
 * The un-reshaped program output is disposed before returning.
 *
 * @param args {inputs: {sparseIndices, sparseValues, defaultValue}, backend,
 *     attrs: {outputShape}}
 * @returns tensor info of the reshaped result, with sparseValues' dtype.
 */
function sparseToDense(args) {
    var backend = args.backend;
    var sparseIndices = args.inputs.sparseIndices;
    var sparseValues = args.inputs.sparseValues;
    var defaultValue = args.inputs.defaultValue;
    var outputShape = args.attrs.outputShape;
    var shapes = tf.backend_util.calculateShapes(sparseValues, sparseIndices, outputShape);
    // Last constructor argument (sumDupeIndices) is always false here.
    var program = new ScatterProgram(shapes.numUpdates, shapes.sliceRank, sparseIndices.shape.length, sparseValues.shape.length, shapes.strides, [shapes.outputSize, 1], false);
    var scattered = backend.runWebGLProgram(program, [sparseValues, sparseIndices, defaultValue], sparseValues.dtype);
    var dense = reshape({
        inputs: { x: scattered },
        backend: backend,
        attrs: { shape: outputShape }
    });
    backend.disposeIntermediateTensorInfo(scattered);
    return dense;
}
// Kernel config tying the tf.SparseToDense kernel name to the 'webgl' backend.
var sparseToDenseConfig = {
    kernelName: tf.SparseToDense,
    backendName: 'webgl',
    kernelFunc: sparseToDense
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * SplitV kernel for the 'webgl' backend.
 *
 * Resolves the axis with tf.util.parseAxisParam and the per-piece sizes with
 * tf.backend_util.prepareSplitSize, then takes one `slice` of x per piece,
 * advancing the start offset along the split axis after each slice.
 *
 * @param args {inputs: {x}, backend, attrs: {numOrSizeSplits, axis}}
 * @returns array of tensor infos, one per split piece, in order.
 */
function splitV(args) {
    var backend = args.backend;
    var x = args.inputs.x;
    var numOrSizeSplits = args.attrs.numOrSizeSplits;
    var axis = args.attrs.axis;
    var $axis = tf.util.parseAxisParam(axis, x.shape)[0];
    var splitSizes = tf.backend_util.prepareSplitSize(x, numOrSizeSplits, $axis);
    // Running start coordinate; only the split-axis entry ever changes.
    var begin = new Array(x.shape.length).fill(0);
    var pieces = [];
    for (var k = 0; k < splitSizes.length; k++) {
        var pieceLength = splitSizes[k];
        // Full extent of x on every axis except the split axis.
        var sliceSize = x.shape.slice();
        sliceSize[$axis] = pieceLength;
        pieces.push(slice({
            inputs: { x: x },
            backend: backend,
            attrs: { begin: begin, size: sliceSize }
        }));
        begin[$axis] += pieceLength;
    }
    return pieces;
}
// Kernel config tying the tf.SplitV kernel name to the 'webgl' backend.
var splitVConfig = {
    kernelName: tf.SplitV,
    backendName: 'webgl',
    kernelFunc: splitV
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// GLSL body returning the square root of x.
var SQRT = "return sqrt(x);";
// Sqrt kernel function built from the snippet above via unaryKernelFunc.
var sqrt = unaryKernelFunc({
    opSnippet: SQRT
});
// Kernel config tying the tf.Sqrt kernel name to the 'webgl' backend.
var sqrtConfig = {
    kernelName: tf.Sqrt,
    backendName: 'webgl',
    kernelFunc: sqrt
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2019 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// GLSL body returning x multiplied by itself.
var SQUARE = "return x * x;";
// Square kernel function built from the snippet above via unaryKernelFunc.
var square = unaryKernelFunc({
    opSnippet: SQUARE
});
// Kernel config tying the tf.Square kernel name to the 'webgl' backend.
var squareConfig = {
    kernelName: tf.Square,
    backendName: 'webgl',
    kernelFunc: square
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// GLSL body returning (a - b)^2 for the two operands a and b.
var SQUARED_DIFFERENCE = 'return (a - b) * (a - b);';
// SquaredDifference kernel function: the same snippet is supplied both as the
// regular and as the packed op snippet of binaryKernelFunc.
var squaredDifference = binaryKernelFunc({
    opSnippet: SQUARED_DIFFERENCE,
    packedOpSnippet: SQUARED_DIFFERENCE
});
// Kernel config tying the tf.SquaredDifference kernel name to the 'webgl' backend.
var squaredDifferenceConfig = {
    kernelName: tf.SquaredDifference,
    backendName: 'webgl',
    kernelFunc: squaredDifference
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Step kernel for the 'webgl' backend.
 *
 * Builds a unary shader snippet that starts with the shared CHECK_NAN_SNIPPET
 * and then returns 1.0 where x > 0.0 and float(attrs.alpha) elsewhere; the
 * alpha value is baked into the shader source as text.
 *
 * @param _a {inputs: {x}, attrs: {alpha}, backend}
 * @returns tensor info produced by running the program on x, with x's dtype.
 */
function step(_a) {
    var backend = _a.backend;
    var x = _a.inputs.x;
    var alpha = _a.attrs.alpha;
    var stepBody = "\n return x > 0.0 ? 1.0 : float(" + alpha + ");\n ";
    var program = new UnaryOpProgram(x.shape, CHECK_NAN_SNIPPET + stepBody);
    return backend.runWebGLProgram(program, [x], x.dtype);
}
// Kernel config tying the tf.Step kernel name to the 'webgl' backend.
var stepConfig = {
    kernelName: tf.Step,
    backendName: 'webgl',
    kernelFunc: step
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2017 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Shader program that gathers a strided slice of its single input `x`.
 *
 * For every output coordinate `coords` the shader reads
 * x[coords * strides + begin] (component-wise for rank != 1).
 *
 * @param begin   start offsets; spliced into the shader source as text inside
 *     a coords-type constructor (arrays stringify comma-separated).
 * @param strides step per axis; spliced into the shader source the same way.
 * @param size    output shape; also determines the rank and the coords type
 *     obtained from getCoordsDataType.
 *
 * Instance fields set: variableNames (['x']), outputShape (size), userCode.
 */
var StridedSliceProgram = /** @class */ (function () {
    function StridedSliceProgram(begin, strides, size) {
        this.variableNames = ['x'];
        this.outputShape = size;
        var rank = size.length;
        // One type serves both the begin/strides constants and `coords`.
        var dtype = getCoordsDataType(rank);
        var newCoords;
        if (rank === 1) {
            // Rank-1 coords are a scalar, so no component indexing is needed.
            newCoords = 'coords * strides + begin';
        }
        else {
            newCoords = size.map(function (_, i) {
                return "coords[" + i + "] * strides[" + i + "] + begin[" + i + "]";
            }).join(',');
        }
        this.userCode = "\n " + dtype + " begin = " + dtype + "(" + begin + ");\n " + dtype + " strides = " + dtype + "(" + strides + ");\n\n void main() {\n " + dtype + " coords = getOutputCoords();\n setOutput(getX(" + newCoords + "));\n }\n ";
    }
    return StridedSliceProgram;
}());
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * StridedSlice kernel for the 'webgl' backend.
 *
 * tf.slice_util.sliceInfo turns the begin/end/strides/mask attributes into a
 * plan; x is first reshaped to the plan's newShape, then one of four paths
 * produces the result:
 *   1. nonStrided         -> plain `slice` followed by a reshape to outShape;
 *   2. outShape has a 0   -> empty tensor info;
 *   3. shouldExecuteOnCPU -> stridedSliceImplCPU on the values held in texData;
 *   4. otherwise          -> StridedSliceProgram on the GPU.
 * The result is reshaped to outShape once more; the reshaped input and the
 * pre-reshape result are disposed.
 *
 * @param args {inputs: {x}, backend, attrs: {begin, end, strides, beginMask,
 *     endMask, ellipsisMask, newAxisMask, shrinkAxisMask}}
 * @returns tensor info of the final reshape.
 */
function stridedSlice(args) {
    var backend = args.backend;
    var attrs = args.attrs;
    var x = args.inputs.x;
    var plan = tf.slice_util.sliceInfo(x.shape, attrs.begin, attrs.end, attrs.strides, attrs.beginMask, attrs.endMask, attrs.ellipsisMask, attrs.newAxisMask, attrs.shrinkAxisMask);
    var outShape = plan.outShape;
    var reshapedInput = reshape({
        inputs: { x: x },
        backend: backend,
        attrs: { shape: plan.newShape }
    });
    var result;
    if (plan.nonStrided) {
        var sliced = slice({
            inputs: { x: reshapedInput },
            backend: backend,
            attrs: { begin: plan.$begin, size: plan.size }
        });
        result = reshape({
            inputs: { x: sliced },
            backend: backend,
            attrs: { shape: outShape }
        });
        backend.disposeIntermediateTensorInfo(sliced);
    }
    else if (outShape.some(function (dim) { return dim === 0; })) {
        result = backend.makeTensorInfo(outShape, x.dtype, []);
    }
    else if (backend.shouldExecuteOnCPU([reshapedInput])) {
        var cpuValues = backend.texData.get(reshapedInput.dataId).values;
        var inputBuffer = tf.buffer(reshapedInput.shape, reshapedInput.dtype, cpuValues);
        var cpuResult = stridedSliceImplCPU(outShape, inputBuffer, plan.$strides, plan.$begin);
        result = backend.makeTensorInfo(outShape, reshapedInput.dtype, cpuResult.values);
    }
    else {
        var program = new StridedSliceProgram(plan.$begin, plan.$strides, outShape);
        result = backend.runWebGLProgram(program, [reshapedInput], reshapedInput.dtype);
    }
    var resultReshaped = reshape({
        inputs: { x: result },
        backend: backend,
        attrs: { shape: outShape }
    });
    backend.disposeIntermediateTensorInfo(reshapedInput);
    backend.disposeIntermediateTensorInfo(result);
    return resultReshaped;
}
// Kernel config tying the tf.StridedSlice kernel name to the 'webgl' backend.
var stridedSliceConfig = {
    kernelName: tf.StridedSlice,
    backendName: 'webgl',
    kernelFunc: stridedSlice
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2021 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * StringNGrams kernel for the 'webgl' backend.
 *
 * Reads `data` and `dataSplits` back synchronously and delegates to
 * stringNGramsImplCPU together with the n-gram attributes.
 *
 * @param args {inputs: {data, dataSplits}, backend, attrs: {separator,
 *     nGramWidths, leftPad, rightPad, padWidth, preserveShortSequences}}
 * @returns array of two tensor infos: the n-grams ('string', one-dimensional)
 *     and the n-gram splits ('int32', using dataSplits' shape).
 */
function stringNGrams(args) {
    var backend = args.backend;
    var attrs = args.attrs;
    var data = args.inputs.data;
    var dataSplits = args.inputs.dataSplits;
    var dataValues = backend.readSync(data.dataId);
    var splitValues = backend.readSync(dataSplits.dataId);
    var computed = stringNGramsImplCPU(dataValues, splitValues, attrs.separator, attrs.nGramWidths, attrs.leftPad, attrs.rightPad, attrs.padWidth, attrs.preserveShortSequences);
    var nGrams = computed[0];
    var nGramsSplits = computed[1];
    var nGramsInfo = backend.makeTensorInfo([nGrams.length], 'string', nGrams);
    var splitsInfo = backend.makeTensorInfo(dataSplits.shape, 'int32', nGramsSplits);
    return [nGramsInfo, splitsInfo];
}
// Kernel config tying the tf.StringNGrams kernel name to the 'webgl' backend.
var stringNGramsConfig = {
    kernelName: tf.StringNGrams,
    backendName: 'webgl',
    kernelFunc: stringNGrams
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2021 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * StringSplit kernel for the 'webgl' backend.
 *
 * Validates the inputs, reads them back synchronously and delegates to
 * stringSplitImplCPU.
 *
 * @param args {inputs: {input, delimiter}, backend, attrs: {skipEmpty}}
 * @returns array of three tensor infos: indices ([n, 2], 'int32'),
 *     values ([n], 'string') and shape ([2], 'int32'), where n is the number
 *     of values returned by the implementation.
 * @throws Error when input is not of dtype 'string', input is not rank 1, or
 *     delimiter is not rank 0.
 */
function stringSplit(args) {
    var backend = args.backend;
    var skipEmpty = args.attrs.skipEmpty;
    var input = args.inputs.input;
    var delimiter = args.inputs.delimiter;
    if (input.dtype !== 'string') {
        throw new Error('Input must be of datatype string');
    }
    if (input.shape.length !== 1) {
        throw new Error("Input must be a vector, got shape: " + input.shape);
    }
    if (delimiter.shape.length !== 0) {
        throw new Error("Delimiter must be a scalar, got shape: " + delimiter.shape);
    }
    var inputValues = backend.readSync(input.dataId);
    var delimiterValue = backend.readSync(delimiter.dataId)[0];
    var pieces = stringSplitImplCPU(inputValues, delimiterValue, skipEmpty);
    var indices = pieces[0];
    var values = pieces[1];
    var shape = pieces[2];
    var count = values.length;
    var indicesInfo = backend.makeTensorInfo([count, 2], 'int32', indices);
    var valuesInfo = backend.makeTensorInfo([count], 'string', values);
    var shapeInfo = backend.makeTensorInfo([2], 'int32', new Int32Array(shape));
    return [indicesInfo, valuesInfo, shapeInfo];
}
// Kernel config tying the tf.StringSplit kernel name to the 'webgl' backend.
var stringSplitConfig = {
    kernelName: tf.StringSplit,
    backendName: 'webgl',
    kernelFunc: stringSplit
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2021 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * StringToHashBucketFast kernel for the 'webgl' backend.
 *
 * Validates the input, reads it back synchronously and delegates to
 * stringToHashBucketFastImplCPU.
 *
 * @param args {inputs: {input}, backend, attrs: {numBuckets}}
 * @returns 'int32' tensor info with the same shape as input.
 * @throws Error when input is not of dtype 'string' or numBuckets <= 0.
 */
function stringToHashBucketFast(args) {
    var backend = args.backend;
    var numBuckets = args.attrs.numBuckets;
    var input = args.inputs.input;
    if (input.dtype !== 'string') {
        throw new Error('Input must be of datatype string');
    }
    if (numBuckets <= 0) {
        throw new Error("Number of buckets must be at least 1");
    }
    var inputValues = backend.readSync(input.dataId);
    var buckets = stringToHashBucketFastImplCPU(inputValues, numBuckets);
    return backend.makeTensorInfo(input.shape, 'int32', buckets);
}
// Kernel config tying the tf.StringToHashBucketFast kernel name to the 'webgl' backend.
var stringToHashBucketFastConfig = {
    kernelName: tf.StringToHashBucketFast,
    backendName: 'webgl',
    kernelFunc: stringToHashBucketFast
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// GLSL snippet handed to unaryKernelFunc as the body of the Tan op.
var TAN = "return tan(x);";
var tan = unaryKernelFunc({ opSnippet: TAN });
// Kernel config tying the Tan kernel name to its 'webgl' implementation.
var tanConfig = {
    kernelName: tf.Tan,
    backendName: 'webgl',
    kernelFunc: tan,
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// GLSL snippet handed to unaryKernelFunc as the body of the Tanh op.
// tanh is expressed through exp(-2 * |x|), so the argument of exp() is never
// positive, and the sign of x is restored afterwards.
var TANH = "\n float e2x = exp(-2.0 * abs(x));\n return sign(x) * (1.0 - e2x) / (1.0 + e2x);\n";
var tanh = unaryKernelFunc({ opSnippet: TANH });
// Kernel config tying the Tanh kernel name to its 'webgl' implementation.
var tanhConfig = {
    kernelName: tf.Tanh,
    backendName: 'webgl',
    kernelFunc: tanh,
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2017 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Shader program description for tiling a tensor.
 *
 * The output shape is the element-wise product of the input shape and `reps`.
 * The generated shader reads, for every output coordinate, the input element
 * addressed by the source-coordinate expression from `getSourceCoords$2`.
 */
var TileProgram = /** @class */ (function () {
    /**
     * @param aShape Shape of the input tensor 'A'.
     * @param reps Repetition count per dimension; indexed in step with
     *     `aShape`.
     */
    function TileProgram(aShape, reps) {
        this.variableNames = ['A'];
        var outputShape = aShape.map(function (dim, i) { return dim * reps[i]; });
        this.outputShape = outputShape;
        this.rank = outputShape.length;
        var dtype = getCoordsDataType(this.rank);
        var sourceCoords = getSourceCoords$2(aShape);
        this.userCode = "\n void main() {\n " + dtype + " resRC = getOutputCoords();\n setOutput(getA(" + sourceCoords + "));\n }\n ";
    }
    return TileProgram;
}());
|
|
/**
 * Builds the GLSL argument list that maps an output coordinate (`resRC`) of a
 * tiled tensor back to the coordinate of the source element: each output
 * coordinate is taken modulo the matching input dimension.
 *
 * @param aShape Shape of the tensor being tiled.
 * @returns For rank 1, a single `imod(resRC, dim)` expression; otherwise a
 *     comma-separated list with one `imod(...)` term per dimension (an empty
 *     string for an empty shape).
 * @throws {Error} If the rank is greater than 5.
 */
function getSourceCoords$2(aShape) {
    var rank = aShape.length;
    if (rank > 5) {
        throw new Error("Tile for rank " + rank + " is not yet supported");
    }
    if (rank === 1) {
        // Rank-1 coordinates are a plain int, so there is no component to select.
        return "imod(resRC, " + aShape[0] + ")";
    }
    var currentCoords = ['resRC.x', 'resRC.y', 'resRC.z', 'resRC.w', 'resRC.u'];
    return aShape
        .map(function (dim, i) { return "imod(" + currentCoords[i] + ", " + dim + ")"; })
        .join(',');
}
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Kernel function for Tile on the 'webgl' backend.
 *
 * String tensors and tensors of rank greater than 5 are tiled by
 * `tileImplCPU` on values read back from the backend; everything else runs
 * as a `TileProgram` shader.
 *
 * @param params Kernel arguments: `inputs.x` (tensor info to tile), `backend`
 *     and `attrs.reps` (repetitions per dimension).
 * @returns Tensor info of the tiled result.
 */
function tile(params) {
    var inputs = params.inputs, backend = params.backend, attrs = params.attrs;
    var x = inputs.x;
    var reps = attrs.reps;
    // tile gpu program cannot handle rank > 5 case.
    if (x.dtype === 'string' || x.shape.length > 5) {
        // Even though a string tensor is always on the CPU, read it through
        // the backend to stay consistent in how tensor data is accessed.
        var data = backend.readSync(x.dataId);
        var value = x.dtype === 'string' ?
            data.map(function (d) { return tf.util.decodeString(d); }) :
            data;
        var buf = tf.buffer(x.shape, x.dtype, value);
        var outBuf = tileImplCPU(buf, reps);
        return backend.makeTensorInfo(outBuf.shape, outBuf.dtype, outBuf.values);
    }
    var program = new TileProgram(x.shape, reps);
    return backend.runWebGLProgram(program, [x], x.dtype);
}
|
|
// Kernel config tying the Tile kernel name to its 'webgl' implementation.
var tileConfig = {
    kernelName: tf.Tile,
    backendName: 'webgl',
    kernelFunc: tile,
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Kernel function for TopK on the 'webgl' backend.
 *
 * The input values are read back synchronously and the selection is done by
 * the shared `topKImplCPU` implementation; its two result buffers (values and
 * indices) are wrapped in new tensor infos.
 *
 * Note: `attrs.sorted` is not consulted here; only `attrs.k` is forwarded.
 *
 * @param args Kernel arguments: `inputs.x`, `backend` (must provide
 *     `readSync` and `makeTensorInfo`) and `attrs.k`.
 * @returns A two-element array: [top-k values, top-k indices] tensor infos.
 */
function topK(args) {
    var inputs = args.inputs, backend = args.backend, attrs = args.attrs;
    var x = inputs.x;
    var k = attrs.k;
    var xVals = backend.readSync(x.dataId);
    var _a = topKImplCPU(xVals, x.shape, x.dtype, k), allTopKVals = _a[0], allTopKIndices = _a[1];
    return [
        backend.makeTensorInfo(allTopKVals.shape, allTopKVals.dtype, allTopKVals.values),
        backend.makeTensorInfo(allTopKIndices.shape, allTopKIndices.dtype, allTopKIndices.values)
    ];
}
|
|
// Kernel config tying the TopK kernel name to its 'webgl' implementation.
var topKConfig = {
    kernelName: tf.TopK,
    backendName: 'webgl',
    kernelFunc: topK
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2021 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Shader program description for the image Transform op.
 *
 * The generated shader reads 8 transform coefficients per batch entry
 * (a1..a3, b1..b3, c1, c2), maps every output pixel to an input coordinate,
 * applies the selected fill mode to that coordinate and samples the image
 * with nearest or bilinear interpolation. Reads outside the image, and pixels
 * whose projection term is zero, produce `fillValue`.
 */
var TransformProgram = /** @class */ (function () {
    /**
     * @param imageHeight Height of the input image, baked into the shader.
     * @param imageWidth Width of the input image, baked into the shader.
     * @param interpolation 'nearest' selects nearest-neighbour sampling; any
     *     other value selects the bilinear branch.
     * @param fillMode 'constant', 'reflect', 'wrap' or 'nearest'; any other
     *     value is treated like 'constant'.
     * @param fillValue Value written for reads outside the image.
     * @param outShape Output shape of the program.
     */
    function TransformProgram(imageHeight, imageWidth, interpolation, fillMode, fillValue, outShape) {
        this.variableNames = ['Image', 'Transforms'];
        this.outputShape = outShape;
        // The modes are baked into the shader source as integer constants.
        var interpolationModeId = interpolation === 'nearest' ? 1 : 2;
        var fillModeId;
        switch (fillMode) {
            case 'constant':
                fillModeId = 1;
                break;
            case 'reflect':
                fillModeId = 2;
                break;
            case 'wrap':
                fillModeId = 3;
                break;
            case 'nearest':
                fillModeId = 4;
                break;
            default:
                fillModeId = 1;
                break;
        }
        this.userCode = "\n float mapCoord(float outCoord, float len) {" +
            "\n float inCoord = outCoord;" +
            "\n if(" + fillModeId + " == 2) {" +
            "\n if (inCoord < 0.0) {" +
            "\n if (len <= 1.0) {" +
            "\n inCoord = 0.0;" +
            "\n } else {" +
            "\n float sz2 = 2.0 * len;" +
            "\n if (inCoord < sz2) {" +
            "\n inCoord = sz2 * float(int(float(-inCoord / sz2))) +" +
            "\n inCoord;" +
            "\n }" +
            "\n inCoord = inCoord < -len ? inCoord + sz2 : -inCoord - 1.0;" +
            "\n }" +
            "\n } else if (inCoord > len - 1.0) {" +
            "\n if (len <= 1.0) {" +
            "\n inCoord = 0.0;" +
            "\n } else {" +
            "\n float sz2 = 2.0 * len;" +
            "\n inCoord -= sz2 * float(int(float(inCoord / sz2)));" +
            "\n if (inCoord >= len) {" +
            "\n inCoord = sz2 - inCoord - 1.0;" +
            "\n }" +
            "\n }" +
            "\n }" +
            "\n return clamp(inCoord, 0.0, len - 1.0);" +
            "\n } else if (" + fillModeId + " == 3) {" +
            "\n if (inCoord < 0.0) {" +
            "\n if (len <= 1.0) {" +
            "\n inCoord = 0.0;" +
            "\n } else {" +
            "\n float sz = len - 1.0;" +
            "\n inCoord += len * (float(int(float(-inCoord / sz))) + 1.0);" +
            "\n }" +
            "\n } else if (inCoord > len - 1.0) {" +
            "\n if (len <= 1.0) {" +
            "\n inCoord = 0.0;" +
            "\n } else {" +
            "\n float sz = len - 1.0;" +
            "\n inCoord -= len * float(int(float(inCoord / sz)));" +
            "\n }" +
            "\n }" +
            "\n return clamp(inCoord, 0.0, len - 1.0);" +
            "\n } else if (" + fillModeId + " == 4) {" +
            "\n return clamp(outCoord, 0.0, len - 1.0);" +
            "\n } else {" +
            "\n return outCoord;" +
            "\n }" +
            "\n }" +
            "\n" +
            "\n float readWithFillValue(int batch, int coordY, int coordX," +
            "\n int channel) {" +
            "\n float outputValue;" +
            "\n if (0 <= coordY && coordY < " + imageHeight + " && 0 <= coordX && coordX < " + imageWidth + ") {" +
            "\n outputValue = getImage(batch, coordY, coordX, channel);" +
            "\n } else {" +
            "\n outputValue = float(" + fillValue + ");" +
            "\n }" +
            "\n return outputValue;" +
            "\n }" +
            "\n" +
            "\n void main() {" +
            "\n ivec4 coords = getOutputCoords();" +
            "\n float outputValue;" +
            "\n int batch = coords[0];" +
            "\n int x = coords[2];" +
            "\n int y = coords[1];" +
            "\n int channel = coords[3];" +
            "\n float xf = float(x);" +
            "\n float yf = float(y);" +
            "\n float a1 = getTransforms(batch, 0);" +
            "\n float a2 = getTransforms(batch, 1);" +
            "\n float a3 = getTransforms(batch, 2);" +
            "\n float b1 = getTransforms(batch, 3);" +
            "\n float b2 = getTransforms(batch, 4);" +
            "\n float b3 = getTransforms(batch, 5);" +
            "\n float c1 = getTransforms(batch, 6);" +
            "\n float c2 = getTransforms(batch, 7);" +
            "\n float projection = c1 * xf + c2 * yf + 1.0;" +
            "\n if (projection == 0.0) {" +
            "\n outputValue = float(" + fillValue + ");" +
            "\n } else {" +
            "\n float inX = (a1 * xf + a2 * yf + a3) / projection;" +
            "\n float inY = (b1 * xf + b2 * yf + b3) / projection;" +
            "\n float mapX = mapCoord(inX, float(" + imageWidth + "));" +
            "\n float mapY = mapCoord(inY, float(" + imageHeight + "));" +
            "\n" +
            "\n if (" + interpolationModeId + " == 1) {" +
            "\n int coordY = int(round(mapY));" +
            "\n int coordX = int(round(mapX));" +
            "\n outputValue = readWithFillValue(batch, coordY, coordX," +
            "\n channel);" +
            "\n } else {" +
            "\n float yFloor = floor(mapY);" +
            "\n float xFloor = floor(mapX);" +
            "\n float yCeil = yFloor + 1.0;" +
            "\n float xCeil = xFloor + 1.0;" +
            "\n float valueYFloor = (xCeil - mapX) *" +
            "\n readWithFillValue(batch, int(yFloor), int(xFloor), channel) +" +
            "\n (mapX - xFloor) *" +
            "\n readWithFillValue(batch, int(yFloor), int(xCeil), channel);" +
            "\n float valueYCeil = (xCeil - mapX) *" +
            "\n readWithFillValue(batch, int(yCeil), int(xFloor), channel) +" +
            "\n (mapX - xFloor) *" +
            "\n readWithFillValue(batch, int(yCeil), int(xCeil), channel);" +
            "\n outputValue = (yCeil - mapY) * valueYFloor +" +
            "\n (mapY - yFloor) * valueYCeil;" +
            "\n }" +
            "\n }" +
            "\n setOutput(outputValue);" +
            "\n }" +
            "\n ";
    }
    return TransformProgram;
}());
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2021 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Kernel function for Transform on the 'webgl' backend.
 *
 * Builds a `TransformProgram` from the image dimensions and the op
 * attributes and runs it with a 'float32' output.
 *
 * @param args Kernel arguments: `inputs.image` (its shape is read as
 *     [batch, height, width, channels]), `inputs.transforms`, `backend`, and
 *     `attrs` with `interpolation`, `fillMode`, `fillValue` and an optional
 *     `outputShape` ([height, width]; defaults to the input image size).
 * @returns The tensor info produced by `backend.runWebGLProgram`.
 */
function transform(args) {
    var inputs = args.inputs, backend = args.backend, attrs = args.attrs;
    var image = inputs.image, transforms = inputs.transforms;
    var interpolation = attrs.interpolation, fillMode = attrs.fillMode, fillValue = attrs.fillValue, outputShape = attrs.outputShape;
    var _a = image.shape, batch = _a[0], imageHeight = _a[1], imageWidth = _a[2], numChannels = _a[3];
    // Without an explicit output size the result keeps the input image size.
    var _b = outputShape != null ? outputShape : [imageHeight, imageWidth], outHeight = _b[0], outWidth = _b[1];
    var outShape = [batch, outHeight, outWidth, numChannels];
    var program = new TransformProgram(imageHeight, imageWidth, interpolation, fillMode, fillValue, outShape);
    return backend.runWebGLProgram(program, [image, transforms], 'float32');
}
|
|
// Kernel config tying the Transform kernel name to its 'webgl' implementation.
var transformConfig = {
    kernelName: tf.Transform,
    backendName: 'webgl',
    kernelFunc: transform
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Kernel function for Unique on the 'webgl' backend.
 *
 * The computation is always forwarded to the shared `uniqueImplCPU`
 * implementation, so the tensor data is read back synchronously first (a
 * console warning is emitted because this can block the UI).
 *
 * @param args Kernel arguments: `inputs.x`, `attrs.axis` and `backend` (must
 *     provide `readSync` and `makeTensorInfo`).
 * @returns A two-element array: the unique values (same dtype as `x`) and an
 *     int32 tensor info holding the indices returned by `uniqueImplCPU`.
 */
function unique(args) {
    var inputs = args.inputs, attrs = args.attrs, backend = args.backend;
    var axis = attrs.axis;
    var x = inputs.x;
    assertNotComplex(x, 'unique');
    // For now, always forward calculation to the CPU backend.
    console.warn('WARNING: ', 'UI might be locked temporarily as data is being downloaded');
    var values = backend.readSync(x.dataId);
    var _a = uniqueImplCPU(values, axis, x.shape, x.dtype), outputValues = _a.outputValues, outputShape = _a.outputShape, indices = _a.indices;
    return [
        backend.makeTensorInfo(outputShape, x.dtype, outputValues),
        backend.makeTensorInfo([indices.length], 'int32', indices),
    ];
}
|
|
// Kernel config tying the Unique kernel name to its 'webgl' implementation.
var uniqueConfig = {
    kernelName: tf.Unique,
    backendName: 'webgl',
    kernelFunc: unique,
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Kernel function for Unpack on the 'webgl' backend.
 *
 * Splits `value` along `axis` into `value.shape[axis]` tensors. Each piece is
 * produced by slicing a size-1 window along `axis` and reshaping it to the
 * input shape with that axis removed. The intermediate slices are disposed
 * before returning.
 *
 * @param args Kernel arguments: `inputs.value`, `backend`, and `attrs.axis`
 *     (a negative axis counts from the last dimension).
 * @returns An array of tensor infos, one per index along `axis`.
 */
function unpack(args) {
    var inputs = args.inputs, backend = args.backend, attrs = args.attrs;
    var value = inputs.value;
    var axis = attrs.axis;
    if (axis < 0) {
        axis += value.shape.length;
    }
    var rank = value.shape.length;
    var num = value.shape[axis];
    // Shape of every output: the input shape without the unpacked axis.
    var outShape = value.shape.filter(function (_, dim) { return dim !== axis; });
    var toDispose = [];
    var begin = new Array(rank).fill(0);
    var size = value.shape.slice();
    size[axis] = 1;
    var res = new Array(num);
    for (var i = 0; i < res.length; i++) {
        // Only the offset along the unpacked axis changes between slices.
        begin[axis] = i;
        var sliced = slice({ inputs: { x: value }, backend: backend, attrs: { begin: begin, size: size } });
        var reshaped = reshape({ inputs: { x: sliced }, backend: backend, attrs: { shape: outShape } });
        res[i] = reshaped;
        toDispose.push(sliced);
    }
    toDispose.forEach(function (t) { return backend.disposeIntermediateTensorInfo(t); });
    return res;
}
|
|
// Kernel config tying the Unpack kernel name to its 'webgl' implementation.
var unpackConfig = {
    kernelName: tf.Unpack,
    backendName: 'webgl',
    kernelFunc: unpack
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2018 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Shader program description for one pass of a segment-sum reduction.
 *
 * The input row of length `inSize` is processed in windows of `windowSize`
 * elements. For every (window, segment) pair the shader sums the values of
 * the window whose segment id equals that segment, so the output has
 * `numSegments * ceil(inSize / windowSize)` columns per batch row.
 */
var SegmentOpProgram = /** @class */ (function () {
    /**
     * @param segOpInfo Object with `windowSize`, `batchSize`, `inSize` and
     *     `numSegments`.
     * @param segOpType Accepted for interface compatibility; not used by the
     *     generated shader.
     */
    function SegmentOpProgram(segOpInfo, segOpType) {
        this.variableNames = ['x', 'segmentIds'];
        var windowSize = segOpInfo.windowSize;
        var batchSize = segOpInfo.batchSize;
        var inSize = segOpInfo.inSize;
        var numSegments = segOpInfo.numSegments;
        var outSize = numSegments * Math.ceil(inSize / windowSize);
        this.outputShape = [batchSize, outSize];
        var initializationValue = '0.0';
        var returnValue = "sumValue";
        // The window is consumed four values at a time; the remainder (0..3)
        // is handled by a dedicated branch after the loop.
        var windowSizeNearestVec4 = Math.floor(windowSize / 4) * 4;
        var windowSizeVec4Remainder = windowSize % 4;
        var updateSnippet = "\n sumValue += dot(values, segFilter);\n ";
        // Bounds checks are only emitted when the last window is partial and
        // can therefore index past the end of the input.
        var checkValueOutOfBounds = '';
        var checkSegmentIdOutOfBounds = '';
        if (inSize % windowSize > 0) {
            checkValueOutOfBounds = "\n if (inIdx < 0 || inIdx >= " + inSize + ") {\n return initializationValue;\n }\n ";
            checkSegmentIdOutOfBounds = "\n if (inIdx < 0 || inIdx >= " + inSize + ") {\n return -1.0;\n }\n ";
        }
        this.userCode = "\n const float initializationValue = " + initializationValue + ";" +
            "\n" +
            "\n float getValue(int batch, int inIdx) {" +
            "\n " + checkValueOutOfBounds +
            "\n return getX(batch, inIdx);" +
            "\n }" +
            "\n" +
            "\n float getSegmentIdAtIndex(int inIdx) {" +
            "\n " + checkSegmentIdOutOfBounds +
            "\n return getSegmentIds(inIdx);" +
            "\n }" +
            "\n" +
            "\n void main() {" +
            "\n ivec2 coords = getOutputCoords();" +
            "\n int batch = coords[0];" +
            "\n int outIdx = coords[1];" +
            "\n int inOffset = int(floor(float(outIdx) / float(" +
            "\n " + numSegments + ")) * float(" + windowSize + "));" +
            "\n int currentSeg = int(mod(float(outIdx), float(" + numSegments + ")));" +
            "\n" +
            "\n float sumValue = 0.0;" +
            "\n" +
            "\n for (int i = 0; i < " + windowSizeNearestVec4 + "; i += 4) {" +
            "\n int inIdx = inOffset + i;" +
            "\n vec4 values = vec4(" +
            "\n getValue(batch, inIdx)," +
            "\n getValue(batch, inIdx + 1)," +
            "\n getValue(batch, inIdx + 2)," +
            "\n getValue(batch, inIdx + 3)" +
            "\n );" +
            "\n" +
            "\n vec4 segFilter = vec4(" +
            "\n int(getSegmentIdAtIndex(inIdx)) == currentSeg ? 1 : 0," +
            "\n int(getSegmentIdAtIndex(inIdx + 1)) == currentSeg ? 1 : 0," +
            "\n int(getSegmentIdAtIndex(inIdx + 2)) == currentSeg ? 1 : 0," +
            "\n int(getSegmentIdAtIndex(inIdx + 3)) == currentSeg ? 1 : 0" +
            "\n );" +
            "\n" +
            "\n " + updateSnippet +
            "\n }" +
            "\n" +
            "\n int inIdx = inOffset + " + windowSizeNearestVec4 + ";" +
            "\n if (" + (windowSizeVec4Remainder === 1) + ") {" +
            "\n vec4 values = vec4(" +
            "\n getValue(batch, inIdx)," +
            "\n initializationValue," +
            "\n initializationValue," +
            "\n initializationValue" +
            "\n );" +
            "\n" +
            "\n int inIdxSeg = int(getSegmentIdAtIndex(inIdx));" +
            "\n" +
            "\n vec4 segFilter = vec4(" +
            "\n int(getSegmentIdAtIndex(inIdx)) == currentSeg ? 1 : 0," +
            "\n 0," +
            "\n 0," +
            "\n 0" +
            "\n );" +
            "\n" +
            "\n " + updateSnippet +
            "\n } else if (" + (windowSizeVec4Remainder === 2) + ") {" +
            "\n vec4 values = vec4(" +
            "\n getValue(batch, inIdx)," +
            "\n getValue(batch, inIdx + 1)," +
            "\n initializationValue," +
            "\n initializationValue" +
            "\n );" +
            "\n" +
            "\n vec4 segFilter = vec4(" +
            "\n int(getSegmentIdAtIndex(inIdx)) == currentSeg ? 1 : 0," +
            "\n int(getSegmentIdAtIndex(inIdx + 1)) == currentSeg ? 1 : 0," +
            "\n 0," +
            "\n 0" +
            "\n );" +
            "\n" +
            "\n " + updateSnippet +
            "\n } else if (" + (windowSizeVec4Remainder === 3) + ") {" +
            "\n vec4 values = vec4(" +
            "\n getValue(batch, inIdx)," +
            "\n getValue(batch, inIdx + 1)," +
            "\n getValue(batch, inIdx + 2)," +
            "\n initializationValue" +
            "\n );" +
            "\n" +
            "\n vec4 segFilter = vec4(" +
            "\n int(getSegmentIdAtIndex(inIdx)) == currentSeg ? 1 : 0," +
            "\n int(getSegmentIdAtIndex(inIdx + 1)) == currentSeg ? 1 : 0," +
            "\n int(getSegmentIdAtIndex(inIdx + 2)) == currentSeg ? 1 : 0," +
            "\n 0" +
            "\n );" +
            "\n" +
            "\n " + updateSnippet +
            "\n }" +
            "\n setOutput(" + returnValue + ");" +
            "\n }" +
            "\n ";
    }
    return SegmentOpProgram;
}());
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
/**
 * Kernel function for UnsortedSegmentSum on the 'webgl' backend.
 *
 * The input is (optionally) transposed, flattened to 2D and then reduced by
 * repeatedly running `SegmentOpProgram` until exactly `numSegments` columns
 * remain. The result is reshaped to the output shape, and the transpose is
 * undone if one was applied. All intermediate tensor infos are disposed
 * before returning.
 *
 * @param args Kernel arguments: `inputs.x`, `inputs.segmentIds`, `backend`
 *     and `attrs.numSegments`.
 * @returns Tensor info holding the per-segment sums.
 */
function unsortedSegmentSum(args) {
    var inputs = args.inputs, backend = args.backend, attrs = args.attrs;
    var x = inputs.x, segmentIds = inputs.segmentIds;
    var numSegments = attrs.numSegments;
    var xRank = x.shape.length;
    var toDispose = [];
    var axis = 0;
    var permutation = tf.backend_util.getAxesPermutation([axis], xRank);
    var permutedX = x;
    if (permutation != null) {
        permutedX = transpose({ inputs: { x: x }, backend: backend, attrs: { perm: permutation } });
        toDispose.push(permutedX);
        axis = tf.backend_util.getInnerMostAxes(1, xRank)[0];
    }
    var outShape = tf.backend_util.segment_util.computeOutShape(permutedX.shape, axis, numSegments);
    var inSize = tf.util.sizeFromShape([permutedX.shape[axis]]);
    var a2D = reshape({ inputs: { x: permutedX }, backend: backend, attrs: { shape: [-1, inSize] } });
    toDispose.push(a2D);
    var outputDType = tf.sumOutType(x.dtype);
    // Runs one reduction pass over a 2D tensor and recurses on the partial
    // sums until a single window covers the whole row.
    var segOpCompute = function (x2D, segOpType, ids, dtype, segCount) {
        var batchSize = x2D.shape[0];
        var rowSize = x2D.shape[1];
        var windowSize = tf.backend_util.segment_util.segOpComputeOptimalWindowSize(rowSize, segCount);
        var segOpInfo = { windowSize: windowSize, inSize: rowSize, batchSize: batchSize, numSegments: segCount };
        var program = new SegmentOpProgram(segOpInfo, segOpType);
        var output = backend.compileAndRun(program, [x2D, ids], dtype);
        toDispose.push(output);
        // No need to run another GPGPU program.
        if (output.shape[1] === segCount) {
            return output;
        }
        // The partial sums are laid out window by window, each window holding
        // one column per segment, so the ids for the next pass are
        // 0..segCount-1 repeated once per window. SegmentOpProgram sizes its
        // output with Math.ceil(inSize / windowSize); the repetition count
        // uses the same rounding so both always agree.
        var rangeInfo = range({
            backend: backend,
            attrs: { start: 0, stop: segCount, step: 1, dtype: 'float32' }
        });
        var tileInfo = tile({
            inputs: { x: rangeInfo },
            backend: backend,
            attrs: { reps: [Math.ceil(rowSize / windowSize)] }
        });
        toDispose.push(rangeInfo);
        toDispose.push(tileInfo);
        return segOpCompute(output, segOpType, tileInfo, dtype, segCount);
    };
    var segOpResult = segOpCompute(a2D, 'unsortedSegmentSum', segmentIds, outputDType, numSegments);
    var reshaped = reshape({ inputs: { x: segOpResult }, backend: backend, attrs: { shape: outShape } });
    var result = reshaped;
    if (permutation != null) {
        // `reshaped` is only an intermediate when the transpose is undone.
        toDispose.push(reshaped);
        var perm = tf.backend_util.getUndoAxesPermutation(permutation);
        result = transpose({ inputs: { x: result }, backend: backend, attrs: { perm: perm } });
    }
    toDispose.forEach(function (t) { return backend.disposeIntermediateTensorInfo(t); });
    return result;
}
|
|
// Kernel config tying the UnsortedSegmentSum kernel name to its 'webgl'
// implementation.
var unsortedSegmentSumConfig = {
    kernelName: tf.UnsortedSegmentSum,
    backendName: 'webgl',
    kernelFunc: unsortedSegmentSum
};
|
|
|
|
/**
|
|
* @license
|
|
* Copyright 2020 Google LLC. All Rights Reserved.
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
* =============================================================================
|
|
*/
|
|
// List all kernel configs here. Every entry is registered with
// tf.registerKernel below.
var kernelConfigs = [
    LRNConfig,
    LRNGradConfig,
    _fusedMatMulConfig,
    absConfig,
    acosConfig,
    acoshConfig,
    addConfig,
    addNConfig,
    allConfig,
    anyConfig,
    argMaxConfig,
    argMinConfig,
    asinConfig,
    asinhConfig,
    atan2Config,
    atanConfig,
    atanhConfig,
    avgPool3DConfig,
    avgPoolConfig,
    avgPoolGrad3DConfig,
    avgPoolGradConfig,
    batchMatMulConfig,
    batchNormConfig,
    batchToSpaceNDConfig,
    bincountConfig,
    castConfig,
    ceilConfig,
    clipByValueConfig,
    complexAbsConfig,
    complexConfig,
    concatConfig,
    conv2DBackpropFilterConfig,
    conv2DBackpropInputConfig,
    conv2DConfig,
    conv3DBackpropFilterV2Config,
    conv3DBackpropInputConfig,
    conv3DConfig,
    cosConfig,
    coshConfig,
    cropAndResizeConfig,
    cumsumConfig,
    denseBincountConfig,
    depthToSpaceConfig,
    depthwiseConv2dNativeBackpropFilterConfig,
    depthwiseConv2dNativeBackpropInputConfig,
    depthwiseConv2dNativeConfig,
    diagConfig,
    dilation2DConfig,
    einsumConfig,
    eluConfig,
    eluGradConfig,
    equalConfig,
    erfConfig,
    expConfig,
    expandDimsConfig,
    expm1Config,
    fftConfig,
    fillConfig,
    flipLeftRightConfig,
    floorConfig,
    floorDivConfig,
    fromPixelsConfig,
    fusedConv2DConfig,
    fusedDepthwiseConv2DConfig,
    gatherNdConfig,
    gatherV2Config,
    greaterConfig,
    greaterEqualConfig,
    identityConfig,
    ifftConfig,
    imagConfig,
    isFiniteConfig,
    isInfConfig,
    isNaNConfig,
    leakyReluConfig,
    lessConfig,
    lessEqualConfig,
    linSpaceConfig,
    log1pConfig,
    logConfig,
    logicalAndConfig,
    logicalNotConfig,
    logicalOrConfig,
    maxConfig,
    maxPool3DConfig,
    maxPoolConfig,
    maxPoolGrad3DConfig,
    maxPoolGradConfig,
    maxPoolWithArgmaxConfig,
    maximumConfig,
    meanConfig,
    minConfig,
    minimumConfig,
    mirrorPadConfig,
    modConfig,
    multinomialConfig,
    multiplyConfig,
    negConfig,
    nonMaxSuppressionV3Config,
    nonMaxSuppressionV4Config,
    nonMaxSuppressionV5Config,
    notEqualConfig,
    oneHotConfig,
    onesLikeConfig,
    packConfig,
    padV2Config,
    powConfig,
    preluConfig,
    prodConfig,
    rangeConfig,
    realConfig,
    realDivConfig,
    reciprocalConfig,
    relu6Config,
    reluConfig,
    reshapeConfig,
    resizeBilinearConfig,
    resizeBilinearGradConfig,
    resizeNearestNeighborConfig,
    resizeNearestNeighborGradConfig,
    reverseConfig,
    rotateWithOffsetConfig,
    roundConfig,
    rsqrtConfig,
    scatterNdConfig,
    selectConfig,
    seluConfig,
    sigmoidConfig,
    signConfig,
    sinConfig,
    sinhConfig,
    sliceConfig,
    softmaxConfig,
    softplusConfig,
    spaceToBatchNDConfig,
    sparseFillEmptyRowsConfig,
    sparseReshapeConfig,
    sparseSegmentMeanConfig,
    sparseSegmentSumConfig,
    sparseToDenseConfig,
    splitVConfig,
    sqrtConfig,
    squareConfig,
    squaredDifferenceConfig,
    stepConfig,
    stridedSliceConfig,
    stringNGramsConfig,
    stringSplitConfig,
    stringToHashBucketFastConfig,
    subConfig,
    sumConfig,
    tanConfig,
    tanhConfig,
    tileConfig,
    topKConfig,
    transformConfig,
    transposeConfig,
    uniqueConfig,
    unpackConfig,
    unsortedSegmentSumConfig,
    zerosLikeConfig
];
// Register every kernel in list order. The callback takes a single argument
// so that forEach's index/array arguments are not passed to registerKernel.
kernelConfigs.forEach(function (kernelConfig) {
    tf.registerKernel(kernelConfig);
});
|
|
|
|
// Names exported by this bundle.
exports.GPGPUContext = GPGPUContext;
exports.MathBackendWebGL = MathBackendWebGL;
exports.forceHalfFloat = forceHalfFloat;
exports.gpgpu_util = gpgpu_util;
exports.setWebGLContext = setWebGLContext;
exports.version_webgl = version;
exports.webgl = webgl;
exports.webgl_util = webgl_util;

// Mark the exports object as a transpiled ES module.
Object.defineProperty(exports, '__esModule', { value: true });
|
|
|
|
})));
|
|
//# sourceMappingURL=tf-backend-webgl.js.map
|