/** * @license * Copyright 2021 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ (function (global, factory) { typeof exports === 'object' && typeof module !== 'undefined' ? factory(exports) : typeof define === 'function' && define.amd ? define(['exports'], factory) : (global = global || self, factory(global.tf = global.tf || {})); }(this, (function (exports) { 'use strict'; /*! ***************************************************************************** Copyright (c) Microsoft Corporation. Permission to use, copy, modify, and/or distribute this software for any purpose with or without fee is hereby granted. THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
***************************************************************************** */
/* global Reflect, Promise */

// ---------------------------------------------------------------------------
// TypeScript runtime helpers (tslib) emitted for ES5 output. They implement
// class inheritance (__extends) and async functions (__awaiter driving a
// __generator state machine) for down-levelled code in this bundle, e.g.
// Environment.prototype.getAsync. Statement order inside these helpers is
// significant; do not hand-edit.
// ---------------------------------------------------------------------------

// Copies static members of constructor `b` onto constructor `d`. The first
// call replaces `extendStatics` itself with the best available strategy:
// Object.setPrototypeOf, then direct `__proto__` assignment (if the engine
// supports it), then a copy of `b`'s own enumerable properties.
var extendStatics = function(d, b) {
    extendStatics = Object.setPrototypeOf ||
        ({ __proto__: [] } instanceof Array && function (d, b) { d.__proto__ = b; }) ||
        function (d, b) { for (var p in b) if (b.hasOwnProperty(p)) d[p] = b[p]; };
    return extendStatics(d, b);
};
// ES5 class inheritance: copies statics, then links d.prototype to
// b.prototype through an intermediate constructor `__` so that
// d.prototype.constructor === d. A null base yields Object.create(null).
function __extends(d, b) {
    extendStatics(d, b);
    function __() { this.constructor = d; }
    d.prototype = b === null ? Object.create(b) : (__.prototype = b.prototype, new __());
}
// Runs an async function body expressed as an iterator. `generator` is
// invoked with `thisArg`/`_arguments` to obtain the iterator; every yielded
// value is adopted into a `P` promise (global Promise when P is undefined) and
// the iterator is resumed with its result, or thrown into on rejection.
function __awaiter(thisArg, _arguments, P, generator) {
    function adopt(value) { return value instanceof P ? value : new P(function (resolve) { resolve(value); }); }
    return new (P || (P = Promise))(function (resolve, reject) {
        function fulfilled(value) { try { step(generator.next(value)); } catch (e) { reject(e); } }
        function rejected(value) { try { step(generator["throw"](value)); } catch (e) { reject(e); } }
        function step(result) { result.done ? resolve(result.value) : adopt(result.value).then(fulfilled, rejected); }
        step((generator = generator.apply(thisArg, _arguments || [])).next());
    });
}
// Emulates an ES2015 generator in ES5. `body` is a switch over `_.label`
// that returns instructions of the form [opcode, arg]. Opcodes visible here:
// 0 = next, 1 = throw, 2 = return (see verb()), 3 = jump to label,
// 4 = yield, 5 = delegate to an inner iterator, 6 = error raised by body,
// 7 = end of a finally block.
function __generator(thisArg, body) {
    var _ = { label: 0, sent: function() { if (t[0] & 1) throw t[1]; return t[1]; }, trys: [], ops: [] }, f, y, t, g;
    return g = { next: verb(0), "throw": verb(1), "return": verb(2) }, typeof Symbol === "function" && (g[Symbol.iterator] = function() { return this; }), g;
    function verb(n) { return function (v) { return step([n, v]); }; }
    function step(op) {
        // Re-entrancy guard: `f` is set while the body or inner iterator runs.
        if (f) throw new TypeError("Generator is already executing.");
        while (_) try {
            // Forward the operation to the delegated iterator `y`, if any.
            if (f = 1, y && (t = op[0] & 2 ? y["return"] : op[0] ? y["throw"] || ((t = y["return"]) && t.call(y), 0) : y.next) && !(t = t.call(y, op[1])).done) return t;
            if (y = 0, t) op = [op[0] & 2, t.value];
            switch (op[0]) {
                case 0: case 1: t = op; break;
                case 4: _.label++; return { value: op[1], done: false };
                case 5: _.label++; y = op[1]; op = [0]; continue;
                case 7: op = _.ops.pop(); _.trys.pop(); continue;
                default:
                    // Route the op through the innermost entry of `_.trys`
                    // (apparently [tryLabel, catchLabel, finallyLabel,
                    // endLabel]); with no enclosing try, error/return ends
                    // the machine by clearing `_`.
                    if (!(t = _.trys, t = t.length > 0 && t[t.length - 1]) && (op[0] === 6 || op[0] === 2)) { _ = 0; continue; }
                    if (op[0] === 3 && (!t || (op[1] > t[0] && op[1] < t[3]))) { _.label = op[1]; break; }
                    if (op[0] === 6 && _.label < t[1]) { _.label = t[1]; t = op; break; }
                    if (t && _.label < t[2]) { _.label = t[2]; _.ops.push(op); break; }
                    if (t[2]) _.ops.pop();
                    _.trys.pop(); continue;
            }
            op = body.call(thisArg, _);
        } catch (e) { op = [6, e]; y = 0; } finally { f = t = 0; }
        // Throw-like opcodes (e.g. 1 = throw, 6 = uncaught error) propagate.
        if (op[0] & 5) throw op[1];
        return { value: op[0] ? op[1] : void 0, done: true };
    }
}

/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/** Epsilon reported by KernelBackend.epsilon() when floatPrecision() is 32. */
var EPSILON_FLOAT32 = 1e-7;
/** Epsilon reported by KernelBackend.epsilon() for any other precision. */
var EPSILON_FLOAT16 = 1e-4;
/** Convenient class for storing tensor-related data.
*/ var DataStorage = /** @class */ (function () { function DataStorage(backend, dataMover) { this.backend = backend; this.dataMover = dataMover; this.data = new WeakMap(); this.dataIdsCount = 0; } DataStorage.prototype.get = function (dataId) { if (!this.data.has(dataId)) { this.dataMover.moveData(this.backend, dataId); } return this.data.get(dataId); }; DataStorage.prototype.set = function (dataId, value) { this.dataIdsCount++; this.data.set(dataId, value); }; DataStorage.prototype.has = function (dataId) { return this.data.has(dataId); }; DataStorage.prototype.delete = function (dataId) { this.dataIdsCount--; return this.data.delete(dataId); }; DataStorage.prototype.numDataIds = function () { return this.dataIdsCount; }; return DataStorage; }()); /** * The interface that defines the kernels that should be implemented when * adding a new backend. New backends don't need to implement every one of the * methods, this can be done gradually (throw an error for unimplemented * methods). */ var KernelBackend = /** @class */ (function () { function KernelBackend() { } KernelBackend.prototype.refCount = function (dataId) { return notYetImplemented('refCount'); }; KernelBackend.prototype.incRef = function (dataId) { return notYetImplemented('incRef'); }; KernelBackend.prototype.timerAvailable = function () { return true; }; KernelBackend.prototype.time = function (f) { return notYetImplemented('time'); }; KernelBackend.prototype.read = function (dataId) { return notYetImplemented('read'); }; KernelBackend.prototype.readSync = function (dataId) { return notYetImplemented('readSync'); }; KernelBackend.prototype.numDataIds = function () { return notYetImplemented('numDataIds'); }; KernelBackend.prototype.disposeData = function (dataId, force) { return notYetImplemented('disposeData'); }; KernelBackend.prototype.write = function (values, shape, dtype) { return notYetImplemented('write'); }; KernelBackend.prototype.move = function (dataId, values, shape, dtype, refCount) { 
return notYetImplemented('move'); }; KernelBackend.prototype.memory = function () { return notYetImplemented('memory'); }; /** Returns the highest precision for floats in bits (e.g. 16 or 32) */ KernelBackend.prototype.floatPrecision = function () { return notYetImplemented('floatPrecision'); }; /** Returns the smallest representable number. */ KernelBackend.prototype.epsilon = function () { return this.floatPrecision() === 32 ? EPSILON_FLOAT32 : EPSILON_FLOAT16; }; KernelBackend.prototype.dispose = function () { return notYetImplemented('dispose'); }; return KernelBackend; }()); function notYetImplemented(kernelName) { throw new Error("'" + kernelName + "' not yet implemented or not found in the registry. " + "This kernel may not be supported by the tfjs backend you have chosen"); } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /** * Shuffles the array in-place using Fisher-Yates algorithm. * * ```js * const a = [1, 2, 3, 4, 5]; * tf.util.shuffle(a); * console.log(a); * ``` * * @param array The array to shuffle in-place. 
* * @doc {heading: 'Util', namespace: 'util'} */ // tslint:disable-next-line:no-any function shuffle(array) { var counter = array.length; var temp = 0; var index = 0; // While there are elements in the array while (counter > 0) { // Pick a random index index = (Math.random() * counter) | 0; // Decrease counter by 1 counter--; // And swap the last element with it temp = array[counter]; array[counter] = array[index]; array[index] = temp; } } /** * Shuffles two arrays in-place the same way using Fisher-Yates algorithm. * * ```js * const a = [1,2,3,4,5]; * const b = [11,22,33,44,55]; * tf.util.shuffleCombo(a, b); * console.log(a, b); * ``` * * @param array The first array to shuffle in-place. * @param array2 The second array to shuffle in-place with the same permutation * as the first array. * * @doc {heading: 'Util', namespace: 'util'} */ function shuffleCombo( // tslint:disable-next-line:no-any array, // tslint:disable-next-line:no-any array2) { if (array.length !== array2.length) { throw new Error("Array sizes must match to be shuffled together " + ("First array length was " + array.length) + ("Second array length was " + array2.length)); } var counter = array.length; var temp, temp2; var index = 0; // While there are elements in the array while (counter > 0) { // Pick a random index index = (Math.random() * counter) | 0; // Decrease counter by 1 counter--; // And swap the last element of each array with it temp = array[counter]; temp2 = array2[counter]; array[counter] = array[index]; array2[counter] = array2[index]; array[index] = temp; array2[index] = temp2; } } /** Clamps a value to a specified range. */ function clamp(min, x, max) { return Math.max(min, Math.min(x, max)); } function nearestLargerEven(val) { return val % 2 === 0 ? val : val + 1; } function sum(arr) { var sum = 0; for (var i = 0; i < arr.length; i++) { sum += arr[i]; } return sum; } /** * Returns a sample from a uniform [a, b) distribution. * * @param a The minimum support (inclusive). 
* @param b The maximum support (exclusive). * @return A pseudorandom number on the half-open interval [a,b). */ function randUniform(a, b) { var r = Math.random(); return (b * r) + (1 - r) * a; } /** Returns the squared Euclidean distance between two vectors. */ function distSquared(a, b) { var result = 0; for (var i = 0; i < a.length; i++) { var diff = Number(a[i]) - Number(b[i]); result += diff * diff; } return result; } /** * Asserts that the expression is true. Otherwise throws an error with the * provided message. * * ```js * const x = 2; * tf.util.assert(x === 2, 'x is not 2'); * ``` * * @param expr The expression to assert (as a boolean). * @param msg A function that returns the message to report when throwing an * error. We use a function for performance reasons. * * @doc {heading: 'Util', namespace: 'util'} */ function assert(expr, msg) { if (!expr) { throw new Error(typeof msg === 'string' ? msg : msg()); } } function assertShapesMatch(shapeA, shapeB, errorMessagePrefix) { if (errorMessagePrefix === void 0) { errorMessagePrefix = ''; } assert(arraysEqual(shapeA, shapeB), function () { return errorMessagePrefix + (" Shapes " + shapeA + " and " + shapeB + " must match"); }); } function assertNonNull(a) { assert(a != null, function () { return "The input to the tensor constructor must be a non-null value."; }); } // NOTE: We explicitly type out what T extends instead of any so that // util.flatten on a nested array of number doesn't try to infer T as a // number[][], causing us to explicitly type util.flatten(). /** * Flattens an arbitrarily nested array. * * ```js * const a = [[1, 2], [3, 4], [5, [6, [7]]]]; * const flat = tf.util.flatten(a); * console.log(flat); * ``` * * @param arr The nested array to flatten. * @param result The destination array which holds the elements. * @param skipTypedArray If true, avoids flattening the typed arrays. Defaults * to false. 
* * @doc {heading: 'Util', namespace: 'util'} */ function flatten(arr, result, skipTypedArray) { if (result === void 0) { result = []; } if (skipTypedArray === void 0) { skipTypedArray = false; } if (result == null) { result = []; } if (Array.isArray(arr) || isTypedArray(arr) && !skipTypedArray) { for (var i = 0; i < arr.length; ++i) { flatten(arr[i], result, skipTypedArray); } } else { result.push(arr); } return result; } /** * Returns the size (number of elements) of the tensor given its shape. * * ```js * const shape = [3, 4, 2]; * const size = tf.util.sizeFromShape(shape); * console.log(size); * ``` * * @doc {heading: 'Util', namespace: 'util'} */ function sizeFromShape(shape) { if (shape.length === 0) { // Scalar. return 1; } var size = shape[0]; for (var i = 1; i < shape.length; i++) { size *= shape[i]; } return size; } function isScalarShape(shape) { return shape.length === 0; } function arraysEqual(n1, n2) { if (n1 === n2) { return true; } if (n1 == null || n2 == null) { return false; } if (n1.length !== n2.length) { return false; } for (var i = 0; i < n1.length; i++) { if (n1[i] !== n2[i]) { return false; } } return true; } function isInt(a) { return a % 1 === 0; } function tanh(x) { // tslint:disable-next-line:no-any if (Math.tanh != null) { // tslint:disable-next-line:no-any return Math.tanh(x); } if (x === Infinity) { return 1; } else if (x === -Infinity) { return -1; } else { var e2x = Math.exp(2 * x); return (e2x - 1) / (e2x + 1); } } function sizeToSquarishShape(size) { var width = Math.ceil(Math.sqrt(size)); return [width, Math.ceil(size / width)]; } /** * Creates a new array with randomized indicies to a given quantity. * * ```js * const randomTen = tf.util.createShuffledIndices(10); * console.log(randomTen); * ``` * * @param number Quantity of how many shuffled indicies to create. 
* * @doc {heading: 'Util', namespace: 'util'} */ function createShuffledIndices(n) { var shuffledIndices = new Uint32Array(n); for (var i = 0; i < n; ++i) { shuffledIndices[i] = i; } shuffle(shuffledIndices); return shuffledIndices; } function rightPad(a, size) { if (size <= a.length) { return a; } return a + ' '.repeat(size - a.length); } function repeatedTry(checkFn, delayFn, maxCounter) { if (delayFn === void 0) { delayFn = function (counter) { return 0; }; } return new Promise(function (resolve, reject) { var tryCount = 0; var tryFn = function () { if (checkFn()) { resolve(); return; } tryCount++; var nextBackoff = delayFn(tryCount); if (maxCounter != null && tryCount >= maxCounter) { reject(); return; } setTimeout(tryFn, nextBackoff); }; tryFn(); }); } /** * Given the full size of the array and a shape that may contain -1 as the * implicit dimension, returns the inferred shape where -1 is replaced. * E.g. For shape=[2, -1, 3] and size=24, it will return [2, 4, 3]. * * @param shape The shape, which may contain -1 in some dimension. * @param size The full size (number of elements) of the array. * @return The inferred shape where -1 is replaced with the inferred size. */ function inferFromImplicitShape(shape, size) { var shapeProd = 1; var implicitIdx = -1; for (var i = 0; i < shape.length; ++i) { if (shape[i] >= 0) { shapeProd *= shape[i]; } else if (shape[i] === -1) { if (implicitIdx !== -1) { throw Error("Shapes can only have 1 implicit size. " + ("Found -1 at dim " + implicitIdx + " and dim " + i)); } implicitIdx = i; } else if (shape[i] < 0) { throw Error("Shapes can not be < 0. 
Found " + shape[i] + " at dim " + i); } } if (implicitIdx === -1) { if (size > 0 && size !== shapeProd) { throw Error("Size(" + size + ") must match the product of shape " + shape); } return shape; } if (shapeProd === 0) { throw Error("Cannot infer the missing size in [" + shape + "] when " + "there are 0 elements"); } if (size % shapeProd !== 0) { throw Error("The implicit shape can't be a fractional number. " + ("Got " + size + " / " + shapeProd)); } var newShape = shape.slice(); newShape[implicitIdx] = size / shapeProd; return newShape; } function parseAxisParam(axis, shape) { var rank = shape.length; // Normalize input axis = axis == null ? shape.map(function (s, i) { return i; }) : [].concat(axis); // Check for valid range assert(axis.every(function (ax) { return ax >= -rank && ax < rank; }), function () { return "All values in axis param must be in range [-" + rank + ", " + rank + ") but " + ("got axis " + axis); }); // Check for only integers assert(axis.every(function (ax) { return isInt(ax); }), function () { return "All values in axis param must be integers but " + ("got axis " + axis); }); // Handle negative axis. return axis.map(function (a) { return a < 0 ? rank + a : a; }); } /** Reduces the shape by removing all dimensions of shape 1. */ function squeezeShape(shape, axis) { var newShape = []; var keptDims = []; var isEmptyArray = axis != null && Array.isArray(axis) && axis.length === 0; var axes = (axis == null || isEmptyArray) ? 
null : parseAxisParam(axis, shape).sort(); var j = 0; for (var i = 0; i < shape.length; ++i) { if (axes != null) { if (axes[j] === i && shape[i] !== 1) { throw new Error("Can't squeeze axis " + i + " since its dim '" + shape[i] + "' is not 1"); } if ((axes[j] == null || axes[j] > i) && shape[i] === 1) { newShape.push(shape[i]); keptDims.push(i); } if (axes[j] <= i) { j++; } } if (shape[i] !== 1) { newShape.push(shape[i]); keptDims.push(i); } } return { newShape: newShape, keptDims: keptDims }; } function getTypedArrayFromDType(dtype, size) { var values = null; if (dtype == null || dtype === 'float32') { values = new Float32Array(size); } else if (dtype === 'int32') { values = new Int32Array(size); } else if (dtype === 'bool') { values = new Uint8Array(size); } else { throw new Error("Unknown data type " + dtype); } return values; } function getArrayFromDType(dtype, size) { var values = null; if (dtype == null || dtype === 'float32') { values = new Float32Array(size); } else if (dtype === 'int32') { values = new Int32Array(size); } else if (dtype === 'bool') { values = new Uint8Array(size); } else if (dtype === 'string') { values = new Array(size); } else { throw new Error("Unknown data type " + dtype); } return values; } function checkConversionForErrors(vals, dtype) { for (var i = 0; i < vals.length; i++) { var num = vals[i]; if (isNaN(num) || !isFinite(num)) { throw Error("A tensor of type " + dtype + " being uploaded contains " + num + "."); } } } /** Returns true if the dtype is valid. */ function isValidDtype(dtype) { return dtype === 'bool' || dtype === 'complex64' || dtype === 'float32' || dtype === 'int32' || dtype === 'string'; } /** * Returns true if the new type can't encode the old type without loss of * precision. 
*/ function hasEncodingLoss(oldType, newType) { if (newType === 'complex64') { return false; } if (newType === 'float32' && oldType !== 'complex64') { return false; } if (newType === 'int32' && oldType !== 'float32' && oldType !== 'complex64') { return false; } if (newType === 'bool' && oldType === 'bool') { return false; } return true; } function isTypedArray(a) { return a instanceof Float32Array || a instanceof Int32Array || a instanceof Uint8Array; } function bytesPerElement(dtype) { if (dtype === 'float32' || dtype === 'int32') { return 4; } else if (dtype === 'complex64') { return 8; } else if (dtype === 'bool') { return 1; } else { throw new Error("Unknown dtype " + dtype); } } /** * Returns the approximate number of bytes allocated in the string array - 2 * bytes per character. Computing the exact bytes for a native string in JS is * not possible since it depends on the encoding of the html page that serves * the website. */ function bytesFromStringArray(arr) { if (arr == null) { return 0; } var bytes = 0; arr.forEach(function (x) { return bytes += x.length; }); return bytes; } /** Returns true if the value is a string. 
*/ function isString(value) { return typeof value === 'string' || value instanceof String; } function isBoolean(value) { return typeof value === 'boolean'; } function isNumber(value) { return typeof value === 'number'; } function inferDtype(values) { if (Array.isArray(values)) { return inferDtype(values[0]); } if (values instanceof Float32Array) { return 'float32'; } else if (values instanceof Int32Array || values instanceof Uint8Array) { return 'int32'; } else if (isNumber(values)) { return 'float32'; } else if (isString(values)) { return 'string'; } else if (isBoolean(values)) { return 'bool'; } return 'float32'; } function isFunction(f) { return !!(f && f.constructor && f.call && f.apply); } function nearestDivisor(size, start) { for (var i = start; i < size; ++i) { if (size % i === 0) { return i; } } return size; } function computeStrides(shape) { var rank = shape.length; if (rank < 2) { return []; } // Last dimension has implicit stride of 1, thus having D-1 (instead of D) // strides. var strides = new Array(rank - 1); strides[rank - 2] = shape[rank - 1]; for (var i = rank - 3; i >= 0; --i) { strides[i] = strides[i + 1] * shape[i + 1]; } return strides; } function createNestedArray(offset, shape, a, isComplex) { if (isComplex === void 0) { isComplex = false; } var ret = new Array(); if (shape.length === 1) { var d = shape[0] * (isComplex ? 2 : 1); for (var i = 0; i < d; i++) { ret[i] = a[offset + i]; } } else { var d = shape[0]; var rest = shape.slice(1); var len = rest.reduce(function (acc, c) { return acc * c; }) * (isComplex ? 2 : 1); for (var i = 0; i < d; i++) { ret[i] = createNestedArray(offset + i * len, rest, a, isComplex); } } return ret; } // Provide a nested array of TypedArray in given shape. function toNestedArray(shape, a, isComplex) { if (isComplex === void 0) { isComplex = false; } if (shape.length === 0) { // Scalar type should return a single number. return a[0]; } var size = shape.reduce(function (acc, c) { return acc * c; }) * (isComplex ? 
2 : 1); if (size === 0) { // A tensor with shape zero should be turned into empty list. return []; } if (size !== a.length) { throw new Error("[" + shape + "] does not match the input size " + a.length + (isComplex ? ' for a complex tensor' : '') + "."); } return createNestedArray(0, shape, a, isComplex); } function makeOnesTypedArray(size, dtype) { var array = makeZerosTypedArray(size, dtype); for (var i = 0; i < array.length; i++) { array[i] = 1; } return array; } function makeZerosTypedArray(size, dtype) { if (dtype == null || dtype === 'float32' || dtype === 'complex64') { return new Float32Array(size); } else if (dtype === 'int32') { return new Int32Array(size); } else if (dtype === 'bool') { return new Uint8Array(size); } else { throw new Error("Unknown data type " + dtype); } } /** * Make nested `TypedArray` filled with zeros. * @param shape The shape information for the nested array. * @param dtype dtype of the array element. */ function makeZerosNestedTypedArray(shape, dtype) { var size = shape.reduce(function (prev, curr) { return prev * curr; }, 1); if (dtype == null || dtype === 'float32') { return toNestedArray(shape, new Float32Array(size)); } else if (dtype === 'int32') { return toNestedArray(shape, new Int32Array(size)); } else if (dtype === 'bool') { return toNestedArray(shape, new Uint8Array(size)); } else { throw new Error("Unknown data type " + dtype); } } function assertNonNegativeIntegerDimensions(shape) { shape.forEach(function (dimSize) { assert(Number.isInteger(dimSize) && dimSize >= 0, function () { return "Tensor must have a shape comprised of positive integers but got " + ("shape [" + shape + "]."); }); }); } /** * Computes flat index for a given location (multidimentionsal index) in a * Tensor/multidimensional array. * * @param locs Location in the tensor. * @param rank Rank of the tensor. * @param strides Tensor strides. 
*/ function locToIndex(locs, rank, strides) { if (rank === 0) { return 0; } else if (rank === 1) { return locs[0]; } var index = locs[locs.length - 1]; for (var i = 0; i < locs.length - 1; ++i) { index += strides[i] * locs[i]; } return index; } /** * Computes the location (multidimensional index) in a tensor/multidimentional * array for a given flat index. * * @param index Index in flat array. * @param rank Rank of tensor. * @param strides Strides of tensor. */ function indexToLoc(index, rank, strides) { if (rank === 0) { return []; } else if (rank === 1) { return [index]; } var locs = new Array(rank); for (var i = 0; i < locs.length - 1; ++i) { locs[i] = Math.floor(index / strides[i]); index -= locs[i] * strides[i]; } locs[locs.length - 1] = index; return locs; } /** * This method asserts whether an object is a Promise instance. * @param object */ // tslint:disable-next-line: no-any function isPromise(object) { // We chose to not use 'obj instanceOf Promise' for two reasons: // 1. It only reliably works for es6 Promise, not other Promise // implementations. // 2. It doesn't work with framework that uses zone.js. zone.js monkey patch // the async calls, so it is possible the obj (patched) is comparing to a // pre-patched Promise. return object && object.then && typeof object.then === 'function'; } /** * @license * Copyright 2017 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
/* ============================================================================= */
// Expects flags from URL in the format ?tfjsflags=FLAG1:1,FLAG2:true.
var TENSORFLOWJS_FLAGS_PREFIX = 'tfjsflags';
/**
 * The environment contains evaluated flags as well as the registered platform.
 * This is always used as a global singleton and can be retrieved with
 * `tf.env()`.
 *
 * @doc {heading: 'Environment'}
 */
var Environment = /** @class */ (function () {
    /**
     * @param global The host global object; its `location.search` (when
     *     present) is scanned for URL flag overrides.
     */
    // tslint:disable-next-line: no-any
    function Environment(global) {
        this.global = global;
        this.flags = {};
        this.flagRegistry = {};
        this.urlFlags = {};
        // Jasmine spies on this in 'environment_test.ts'
        this.getQueryParams = getQueryParams;
        this.populateURLFlags();
    }
    /**
     * Sets the active platform. If one was already set, warns and overwrites
     * it.
     */
    Environment.prototype.setPlatform = function (platformName, platform) {
        if (this.platform != null) {
            // Name the incoming platform: concatenating the platform object
            // itself typically prints "[object Object]".
            console.warn("Platform " + this.platformName + " has already been set. " +
                ("Overwriting the platform with " + platformName + "."));
        }
        this.platformName = platformName;
        this.platform = platform;
    };
    /**
     * Registers a flag. `evaluationFn` computes its value lazily on first
     * read. `setHook`, if given, runs whenever the flag is set. A URL override
     * for this flag is applied immediately.
     */
    Environment.prototype.registerFlag = function (flagName, evaluationFn, setHook) {
        this.flagRegistry[flagName] = { evaluationFn: evaluationFn, setHook: setHook };
        // Override the flag value from the URL. This has to happen here because the
        // environment is initialized before flags get registered.
        if (this.urlFlags[flagName] != null) {
            var flagValue = this.urlFlags[flagName];
            console.warn("Setting feature override from URL " + flagName + ": " + flagValue + ".");
            this.set(flagName, flagValue);
        }
    };
    /**
     * Async variant of get(). Returns the cached value, or awaits the flag's
     * evaluation function and caches its result.
     */
    Environment.prototype.getAsync = function (flagName) {
        return __awaiter(this, void 0, void 0, function () {
            var _a, _b;
            return __generator(this, function (_c) {
                switch (_c.label) {
                    case 0:
                        if (flagName in this.flags) {
                            return [2 /*return*/, this.flags[flagName]];
                        }
                        _a = this.flags;
                        _b = flagName;
                        return [4 /*yield*/, this.evaluateFlag(flagName)];
                    case 1:
                        _a[_b] = _c.sent();
                        return [2 /*return*/, this.flags[flagName]];
                }
            });
        });
    };
    /**
     * Returns the flag value, evaluating and caching it on first access.
     * @throws Error if the flag is not registered, or if its evaluation
     *     function returns a promise (use getAsync() instead).
     */
    Environment.prototype.get = function (flagName) {
        if (flagName in this.flags) {
            return this.flags[flagName];
        }
        var flagValue = this.evaluateFlag(flagName);
        if (isPromise(flagValue)) {
            throw new Error("Flag " + flagName + " cannot be synchronously evaluated. " +
                "Please use getAsync() instead.");
        }
        this.flags[flagName] = flagValue;
        return this.flags[flagName];
    };
    /** Same as get(); intended for numeric flags. */
    Environment.prototype.getNumber = function (flagName) {
        return this.get(flagName);
    };
    /** Same as get(); intended for boolean flags. */
    Environment.prototype.getBool = function (flagName) {
        return this.get(flagName);
    };
    /** Returns the live map of already-evaluated flags. */
    Environment.prototype.getFlags = function () {
        return this.flags;
    };
    Object.defineProperty(Environment.prototype, "features", {
        // For backwards compatibility.
        get: function () {
            return this.flags;
        },
        enumerable: true,
        configurable: true
    });
    /**
     * Sets a registered flag's value and runs its setHook, if any.
     * @throws Error if the flag has not been registered.
     */
    Environment.prototype.set = function (flagName, value) {
        if (this.flagRegistry[flagName] == null) {
            throw new Error("Cannot set flag " + flagName + " as it has not been registered.");
        }
        this.flags[flagName] = value;
        if (this.flagRegistry[flagName].setHook != null) {
            this.flagRegistry[flagName].setHook(value);
        }
    };
    /**
     * Runs the flag's evaluation function without caching the result.
     * @throws Error if the flag has no registry entry.
     */
    Environment.prototype.evaluateFlag = function (flagName) {
        if (this.flagRegistry[flagName] == null) {
            throw new Error("Cannot evaluate flag '" + flagName + "': no evaluation function found.");
        }
        return this.flagRegistry[flagName].evaluationFn();
    };
    /** Replaces the evaluated-flag cache with a shallow copy of `flags`. */
    Environment.prototype.setFlags = function (flags) {
        this.flags = Object.assign({}, flags);
    };
    /**
     * Clears evaluated flags and URL overrides, then re-reads the URL.
     * Registered flags are kept.
     */
    Environment.prototype.reset = function () {
        this.flags = {};
        this.urlFlags = {};
        this.populateURLFlags();
    };
    /**
     * Parses `tfjsflags=KEY:VALUE,...` from `global.location.search` into
     * `urlFlags`. Does nothing when the global has no location.
     */
    Environment.prototype.populateURLFlags = function () {
        var _this = this;
        if (typeof this.global === 'undefined' ||
            typeof this.global.location === 'undefined' ||
            typeof this.global.location.search === 'undefined') {
            return;
        }
        var urlParams = this.getQueryParams(this.global.location.search);
        if (TENSORFLOWJS_FLAGS_PREFIX in urlParams) {
            var keyValues = urlParams[TENSORFLOWJS_FLAGS_PREFIX].split(',');
            keyValues.forEach(function (keyValue) {
                var _a = keyValue.split(':'), key = _a[0], value = _a[1];
                _this.urlFlags[key] = parseValue(key, value);
            });
        }
    };
    return Environment;
}());
/**
 * Parses a URL query string into a {name: value} map of URI-decoded strings.
 * Parameters without '=' map to ''.
 */
function getQueryParams(queryString) {
    var params = {};
    queryString.replace(/[?&]([^=?&]+)(?:=([^&]*))?/g, function (s) {
        var t = [];
        for (var _i = 1; _i < arguments.length; _i++) {
            t[_i - 1] = arguments[_i];
        }
        decodeParam(params, t[0], t[1]);
        return t.join('=');
    });
    return params;
}
/** Stores the URI-decoded `value` (or '' when absent) under the decoded `name`. */
function decodeParam(params, name, value) {
    params[decodeURIComponent(name)] = decodeURIComponent(value || '');
}
/**
 * Parses a URL flag value. 'true'/'false' (case-insensitive) become booleans.
 * Strings that round-trip through Number (e.g. '1', '0.5') become numbers.
 * @throws Error for anything else.
 */
function parseValue(flagName, value) {
    value = value.toLowerCase();
    if (value === 'true' || value === 'false') {
        return value === 'true';
    } else if ("" + +value === value) {
        return +value;
    }
    throw new Error("Could not parse value flag value " + value + " for flag " + flagName + ".");
}
/**
 * Returns the current environment (a global singleton).
 *
 * The environment object contains the evaluated feature values as well as the
 * active platform.
 *
 * @doc {heading: 'Environment'}
 */
function env() {
    return exports.ENV;
}
exports.ENV = null;
/** Installs `environment` as the singleton returned by env(). */
function setEnvironmentGlobal(environment) {
    exports.ENV = environment;
}
/**
 * @license
 * Copyright 2020 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
// Note that the identifier globalNameSpace is scoped to this module, but will
// always resolve to the same global object regardless of how the module is
// resolved.
// tslint:disable-next-line:no-any var globalNameSpace; // tslint:disable-next-line:no-any function getGlobalNamespace() { if (globalNameSpace == null) { // tslint:disable-next-line:no-any var ns = void 0; if (typeof (window) !== 'undefined') { ns = window; } else if (typeof (global) !== 'undefined') { ns = global; } else if (typeof (process) !== 'undefined') { ns = process; } else if (typeof (self) !== 'undefined') { ns = self; } else { throw new Error('Could not find a global object'); } globalNameSpace = ns; } return globalNameSpace; } // tslint:disable-next-line:no-any function getGlobalMap() { var ns = getGlobalNamespace(); if (ns._tfGlobals == null) { ns._tfGlobals = new Map(); } return ns._tfGlobals; } /** * Returns a globally accessible 'singleton' object. * * @param key the name of the object * @param init a function to initialize to initialize this object * the first time it is fetched. */ function getGlobal(key, init) { var globalMap = getGlobalMap(); if (globalMap.has(key)) { return globalMap.get(key); } else { var singleton = init(); globalMap.set(key, singleton); return globalMap.get(key); } } var Abs = 'Abs'; var Acos = 'Acos'; var Acosh = 'Acosh'; var Add = 'Add'; var AddN = 'AddN'; var All = 'All'; var Any = 'Any'; var ArgMax = 'ArgMax'; var ArgMin = 'ArgMin'; var Asin = 'Asin'; var Asinh = 'Asinh'; var Atan = 'Atan'; var Atanh = 'Atanh'; var Atan2 = 'Atan2'; var AvgPool = 'AvgPool'; var AvgPoolGrad = 'AvgPoolGrad'; var AvgPool3D = 'AvgPool3D'; var AvgPool3DGrad = 'AvgPool3DGrad'; var BatchMatMul = 'BatchMatMul'; var BatchToSpaceND = 'BatchToSpaceND'; var Bincount = 'Bincount'; var BroadcastTo = 'BroadcastTo'; var Cast = 'Cast'; var Ceil = 'Ceil'; var ClipByValue = 'ClipByValue'; var Complex = 'Complex'; var ComplexAbs = 'ComplexAbs'; var Concat = 'Concat'; var Conv2D = 'Conv2D'; var Conv2DBackpropFilter = 'Conv2DBackpropFilter'; var Conv2DBackpropInput = 'Conv2DBackpropInput'; var Conv3D = 'Conv3D'; var Conv3DBackpropFilterV2 = 
'Conv3DBackpropFilterV2'; var Conv3DBackpropInputV2 = 'Conv3DBackpropInputV2'; var Cos = 'Cos'; var Cosh = 'Cosh'; var Cumsum = 'Cumsum'; var CropAndResize = 'CropAndResize'; var DenseBincount = 'DenseBincount'; var DepthToSpace = 'DepthToSpace'; var DepthwiseConv2dNative = 'DepthwiseConv2dNative'; var DepthwiseConv2dNativeBackpropFilter = 'DepthwiseConv2dNativeBackpropFilter'; var DepthwiseConv2dNativeBackpropInput = 'DepthwiseConv2dNativeBackpropInput'; var Diag = 'Diag'; var Dilation2D = 'Dilation2D'; var Dilation2DBackpropInput = 'Dilation2DBackpropInput'; var Dilation2DBackpropFilter = 'Dilation2DBackpropFilter'; var RealDiv = 'RealDiv'; var Einsum = 'Einsum'; var Elu = 'Elu'; var EluGrad = 'EluGrad'; var Erf = 'Erf'; var Equal = 'Equal'; var Exp = 'Exp'; var ExpandDims = 'ExpandDims'; var Expm1 = 'Expm1'; var FFT = 'FFT'; var Fill = 'Fill'; var FlipLeftRight = 'FlipLeftRight'; var Floor = 'Floor'; var FloorDiv = 'FloorDiv'; var FusedBatchNorm = 'FusedBatchNorm'; var GatherV2 = 'GatherV2'; var GatherNd = 'GatherNd'; var Greater = 'Greater'; var GreaterEqual = 'GreaterEqual'; var Identity = 'Identity'; var IFFT = 'IFFT'; var Imag = 'Imag'; var IsFinite = 'IsFinite'; var IsInf = 'IsInf'; var IsNan = 'IsNan'; var LeakyRelu = 'LeakyRelu'; var Less = 'Less'; var LessEqual = 'LessEqual'; var LinSpace = 'LinSpace'; var Log = 'Log'; var Log1p = 'Log1p'; var LogicalAnd = 'LogicalAnd'; var LogicalNot = 'LogicalNot'; var LogicalOr = 'LogicalOr'; var LogSoftmax = 'LogSoftmax'; var LRN = 'LRN'; var LRNGrad = 'LRNGrad'; var Max = 'Max'; var Maximum = 'Maximum'; var MaxPool = 'MaxPool'; var MaxPoolGrad = 'MaxPoolGrad'; var MaxPool3D = 'MaxPool3D'; var MaxPool3DGrad = 'MaxPool3DGrad'; var MaxPoolWithArgmax = 'MaxPoolWithArgmax'; var Mean = 'Mean'; var Min = 'Min'; var Minimum = 'Minimum'; var MirrorPad = 'MirrorPad'; var Mod = 'Mod'; var Multinomial = 'Multinomial'; var Multiply = 'Multiply'; var Neg = 'Neg'; var NotEqual = 'NotEqual'; var NonMaxSuppressionV3 = 
'NonMaxSuppressionV3'; var NonMaxSuppressionV4 = 'NonMaxSuppressionV4'; var NonMaxSuppressionV5 = 'NonMaxSuppressionV5'; var OnesLike = 'OnesLike'; var OneHot = 'OneHot'; var Pack = 'Pack'; var PadV2 = 'PadV2'; var Pool = 'Pool'; var Pow = 'Pow'; var Prelu = 'Prelu'; var Prod = 'Prod'; var Range = 'Range'; var Real = 'Real'; var Reciprocal = 'Reciprocal'; var Relu = 'Relu'; var Reshape = 'Reshape'; var ResizeNearestNeighbor = 'ResizeNearestNeighbor'; var ResizeNearestNeighborGrad = 'ResizeNearestNeighborGrad'; var ResizeBilinear = 'ResizeBilinear'; var ResizeBilinearGrad = 'ResizeBilinearGrad'; var Relu6 = 'Relu6'; var Reverse = 'Reverse'; var Round = 'Round'; var Rsqrt = 'Rsqrt'; var ScatterNd = 'ScatterNd'; var Select = 'Select'; var Selu = 'Selu'; var Slice = 'Slice'; var Sin = 'Sin'; var Sinh = 'Sinh'; var Sign = 'Sign'; var Sigmoid = 'Sigmoid'; var Softplus = 'Softplus'; var Sqrt = 'Sqrt'; var Sum = 'Sum'; var SpaceToBatchND = 'SpaceToBatchND'; var SplitV = 'SplitV'; var Softmax = 'Softmax'; var SparseFillEmptyRows = 'SparseFillEmptyRows'; var SparseReshape = 'SparseReshape'; var SparseSegmentMean = 'SparseSegmentMean'; var SparseSegmentSum = 'SparseSegmentSum'; var SparseToDense = 'SparseToDense'; var SquaredDifference = 'SquaredDifference'; var Square = 'Square'; var StridedSlice = 'StridedSlice'; var StringNGrams = 'StringNGrams'; var StringSplit = 'StringSplit'; var StringToHashBucketFast = 'StringToHashBucketFast'; var Sub = 'Sub'; var Tan = 'Tan'; var Tanh = 'Tanh'; var Tile = 'Tile'; var TopK = 'TopK'; var Transform = 'Transform'; var Transpose = 'Transpose'; var Unique = 'Unique'; var Unpack = 'Unpack'; var UnsortedSegmentSum = 'UnsortedSegmentSum'; var ZerosLike = 'ZerosLike'; /** * TensorFlow.js-only kernels */ var Step = 'Step'; var FromPixels = 'FromPixels'; var RotateWithOffset = 'RotateWithOffset'; var _FusedMatMul = '_FusedMatMul'; var FusedConv2D = 'FusedConv2D'; var FusedDepthwiseConv2D = 'FusedDepthwiseConv2D'; /** * @license * Copyright 2019 
Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var kernelRegistry = getGlobal('kernelRegistry', function () { return new Map(); }); var gradRegistry = getGlobal('gradRegistry', function () { return new Map(); }); /** * Returns the kernel function (code) associated with the provided names. * * @param kernelName The official name of the kernel. * @param backendName The official name of the backend. */ function getKernel(kernelName, backendName) { var key = makeKey(kernelName, backendName); return kernelRegistry.get(key); } /** * Returns the registered gradient info associated with the provided kernel. * @param kernelName The official TF kernel name. */ function getGradient(kernelName) { return gradRegistry.get(kernelName); } function getKernelsForBackend(backendName) { var it = kernelRegistry.entries(); var result = []; while (true) { var _a = it.next(), done = _a.done, value = _a.value; if (done) { break; } var key = value[0], config = value[1]; var backend = key.split('_')[0]; if (backend === backendName) { result.push(config); } } return result; } /** * Registers the function (forward pass) for the kernel in a global registry. * * @param config A config object with the following properties: * - `kernelName` The official name of the kernel. * - `backendName` The official name of the backend. * - `kernelFunc` The function to run during the forward pass of the kernel. 
* - `setupFunc` Optional. Gets called once, after the backend initializes. * - `disposeFunc` Optional. Gets called once, right before the backend is * disposed. */ function registerKernel(config) { var kernelName = config.kernelName, backendName = config.backendName; var key = makeKey(kernelName, backendName); if (kernelRegistry.has(key)) { console.warn("The kernel '" + kernelName + "' for backend " + ("'" + backendName + "' is already registered")); } kernelRegistry.set(key, config); } /** * Registers a gradient function for a given kernel in the global registry, * to be used during the back-propagation of that kernel. * * @param config An object with the following properties: * - `kernelName` The name of the kernel that the gradient function is for. * - `gradFunc` The function to run during back-propagation. */ function registerGradient(config) { var kernelName = config.kernelName; if (gradRegistry.has(kernelName)) { // TODO (yassogba) after 3.0 assess whether we need to keep this gated // to debug mode. if (env().getBool('DEBUG')) { console.warn("Overriding the gradient for '" + kernelName + "'"); } } gradRegistry.set(kernelName, config); } /** * Removes the kernel function from the registry. * * @param kernelName The official name of the kernel. * @param backendName The official name of the backend. * */ function unregisterKernel(kernelName, backendName) { var key = makeKey(kernelName, backendName); if (!kernelRegistry.has(key)) { throw new Error("The kernel '" + kernelName + "' for backend " + ("'" + backendName + "' is not registered")); } kernelRegistry.delete(key); } /** Removes the registered gradient from the global registry. */ function unregisterGradient(kernelName) { if (!gradRegistry.has(kernelName)) { throw new Error("The gradient '" + kernelName + "' for backend is not registered"); } gradRegistry.delete(kernelName); } /** * Finds kernels that have already been registered to a backend and re-registers * them for a new backend. 
Useful for registering custom backends. * @param registeredBackendName Already registered backend. * @param newBackendName New backend. */ function copyRegisteredKernels(registeredBackendName, newBackendName) { var kernels = getKernelsForBackend(registeredBackendName); kernels.forEach(function (kernelConfig) { var newKernelConfig = Object.assign({}, kernelConfig, { backendName: newBackendName }); registerKernel(newKernelConfig); }); } function makeKey(kernelName, backendName) { return backendName + "_" + kernelName; } var long_1 = Long; /** * wasm optimizations, to do native i64 multiplication and divide */ var wasm = null; try { wasm = new WebAssembly.Instance(new WebAssembly.Module(new Uint8Array([ 0, 97, 115, 109, 1, 0, 0, 0, 1, 13, 2, 96, 0, 1, 127, 96, 4, 127, 127, 127, 127, 1, 127, 3, 7, 6, 0, 1, 1, 1, 1, 1, 6, 6, 1, 127, 1, 65, 0, 11, 7, 50, 6, 3, 109, 117, 108, 0, 1, 5, 100, 105, 118, 95, 115, 0, 2, 5, 100, 105, 118, 95, 117, 0, 3, 5, 114, 101, 109, 95, 115, 0, 4, 5, 114, 101, 109, 95, 117, 0, 5, 8, 103, 101, 116, 95, 104, 105, 103, 104, 0, 0, 10, 191, 1, 6, 4, 0, 35, 0, 11, 36, 1, 1, 126, 32, 0, 173, 32, 1, 173, 66, 32, 134, 132, 32, 2, 173, 32, 3, 173, 66, 32, 134, 132, 126, 34, 4, 66, 32, 135, 167, 36, 0, 32, 4, 167, 11, 36, 1, 1, 126, 32, 0, 173, 32, 1, 173, 66, 32, 134, 132, 32, 2, 173, 32, 3, 173, 66, 32, 134, 132, 127, 34, 4, 66, 32, 135, 167, 36, 0, 32, 4, 167, 11, 36, 1, 1, 126, 32, 0, 173, 32, 1, 173, 66, 32, 134, 132, 32, 2, 173, 32, 3, 173, 66, 32, 134, 132, 128, 34, 4, 66, 32, 135, 167, 36, 0, 32, 4, 167, 11, 36, 1, 1, 126, 32, 0, 173, 32, 1, 173, 66, 32, 134, 132, 32, 2, 173, 32, 3, 173, 66, 32, 134, 132, 129, 34, 4, 66, 32, 135, 167, 36, 0, 32, 4, 167, 11, 36, 1, 1, 126, 32, 0, 173, 32, 1, 173, 66, 32, 134, 132, 32, 2, 173, 32, 3, 173, 66, 32, 134, 132, 130, 34, 4, 66, 32, 135, 167, 36, 0, 32, 4, 167, 11 ])), {}).exports; } catch (e) { // no wasm support :( } /** * Constructs a 64 bit two's-complement integer, given its low and high 32 bit 
values as *signed* integers. * See the from* functions below for more convenient ways of constructing Longs. * @exports Long * @class A Long class for representing a 64 bit two's-complement integer value. * @param {number} low The low (signed) 32 bits of the long * @param {number} high The high (signed) 32 bits of the long * @param {boolean=} unsigned Whether unsigned or not, defaults to signed * @constructor */ function Long(low, high, unsigned) { /** * The low 32 bits as a signed value. * @type {number} */ this.low = low | 0; /** * The high 32 bits as a signed value. * @type {number} */ this.high = high | 0; /** * Whether unsigned or not. * @type {boolean} */ this.unsigned = !!unsigned; } // The internal representation of a long is the two given signed, 32-bit values. // We use 32-bit pieces because these are the size of integers on which // Javascript performs bit-operations. For operations like addition and // multiplication, we split each number into 16 bit pieces, which can easily be // multiplied within Javascript's floating-point representation without overflow // or change in sign. // // In the algorithms below, we frequently reduce the negative case to the // positive case by negating the input(s) and then post-processing the result. // Note that we must ALWAYS check specially whether those values are MIN_VALUE // (-2^63) because -MIN_VALUE == MIN_VALUE (since 2^63 cannot be represented as // a positive number, it overflows back into a negative). Not handling this // case would often result in infinite recursion. // // Common constant values ZERO, ONE, NEG_ONE, etc. are defined below the from* // methods on which they depend. /** * An indicator used to reliably determine if an object is a Long or not. 
* @type {boolean} * @const * @private */ Long.prototype.__isLong__; Object.defineProperty(Long.prototype, "__isLong__", { value: true }); /** * @function * @param {*} obj Object * @returns {boolean} * @inner */ function isLong(obj) { return (obj && obj["__isLong__"]) === true; } /** * Tests if the specified object is a Long. * @function * @param {*} obj Object * @returns {boolean} */ Long.isLong = isLong; /** * A cache of the Long representations of small integer values. * @type {!Object} * @inner */ var INT_CACHE = {}; /** * A cache of the Long representations of small unsigned integer values. * @type {!Object} * @inner */ var UINT_CACHE = {}; /** * @param {number} value * @param {boolean=} unsigned * @returns {!Long} * @inner */ function fromInt(value, unsigned) { var obj, cachedObj, cache; if (unsigned) { value >>>= 0; if (cache = (0 <= value && value < 256)) { cachedObj = UINT_CACHE[value]; if (cachedObj) return cachedObj; } obj = fromBits(value, (value | 0) < 0 ? -1 : 0, true); if (cache) UINT_CACHE[value] = obj; return obj; } else { value |= 0; if (cache = (-128 <= value && value < 128)) { cachedObj = INT_CACHE[value]; if (cachedObj) return cachedObj; } obj = fromBits(value, value < 0 ? -1 : 0, false); if (cache) INT_CACHE[value] = obj; return obj; } } /** * Returns a Long representing the given 32 bit integer value. * @function * @param {number} value The 32 bit integer in question * @param {boolean=} unsigned Whether unsigned or not, defaults to signed * @returns {!Long} The corresponding Long value */ Long.fromInt = fromInt; /** * @param {number} value * @param {boolean=} unsigned * @returns {!Long} * @inner */ function fromNumber(value, unsigned) { if (isNaN(value)) return unsigned ? 
UZERO : ZERO; if (unsigned) { if (value < 0) return UZERO; if (value >= TWO_PWR_64_DBL) return MAX_UNSIGNED_VALUE; } else { if (value <= -TWO_PWR_63_DBL) return MIN_VALUE; if (value + 1 >= TWO_PWR_63_DBL) return MAX_VALUE; } if (value < 0) return fromNumber(-value, unsigned).neg(); return fromBits((value % TWO_PWR_32_DBL) | 0, (value / TWO_PWR_32_DBL) | 0, unsigned); } /** * Returns a Long representing the given value, provided that it is a finite number. Otherwise, zero is returned. * @function * @param {number} value The number in question * @param {boolean=} unsigned Whether unsigned or not, defaults to signed * @returns {!Long} The corresponding Long value */ Long.fromNumber = fromNumber; /** * @param {number} lowBits * @param {number} highBits * @param {boolean=} unsigned * @returns {!Long} * @inner */ function fromBits(lowBits, highBits, unsigned) { return new Long(lowBits, highBits, unsigned); } /** * Returns a Long representing the 64 bit integer that comes by concatenating the given low and high bits. Each is * assumed to use 32 bits. * @function * @param {number} lowBits The low 32 bits * @param {number} highBits The high 32 bits * @param {boolean=} unsigned Whether unsigned or not, defaults to signed * @returns {!Long} The corresponding Long value */ Long.fromBits = fromBits; /** * @function * @param {number} base * @param {number} exponent * @returns {number} * @inner */ var pow_dbl = Math.pow; // Used 4 times (4*8 to 15+4) /** * @param {string} str * @param {(boolean|number)=} unsigned * @param {number=} radix * @returns {!Long} * @inner */ function fromString(str, unsigned, radix) { if (str.length === 0) throw Error('empty string'); if (str === "NaN" || str === "Infinity" || str === "+Infinity" || str === "-Infinity") return ZERO; if (typeof unsigned === 'number') { // For goog.math.long compatibility radix = unsigned, unsigned = false; } else { unsigned = !! 
unsigned; } radix = radix || 10; if (radix < 2 || 36 < radix) throw RangeError('radix'); var p; if ((p = str.indexOf('-')) > 0) throw Error('interior hyphen'); else if (p === 0) { return fromString(str.substring(1), unsigned, radix).neg(); } // Do several (8) digits each time through the loop, so as to // minimize the calls to the very expensive emulated div. var radixToPower = fromNumber(pow_dbl(radix, 8)); var result = ZERO; for (var i = 0; i < str.length; i += 8) { var size = Math.min(8, str.length - i), value = parseInt(str.substring(i, i + size), radix); if (size < 8) { var power = fromNumber(pow_dbl(radix, size)); result = result.mul(power).add(fromNumber(value)); } else { result = result.mul(radixToPower); result = result.add(fromNumber(value)); } } result.unsigned = unsigned; return result; } /** * Returns a Long representation of the given string, written using the specified radix. * @function * @param {string} str The textual representation of the Long * @param {(boolean|number)=} unsigned Whether unsigned or not, defaults to signed * @param {number=} radix The radix in which the text is written (2-36), defaults to 10 * @returns {!Long} The corresponding Long value */ Long.fromString = fromString; /** * @function * @param {!Long|number|string|!{low: number, high: number, unsigned: boolean}} val * @param {boolean=} unsigned * @returns {!Long} * @inner */ function fromValue(val, unsigned) { if (typeof val === 'number') return fromNumber(val, unsigned); if (typeof val === 'string') return fromString(val, unsigned); // Throws for non-objects, converts non-instanceof Long: return fromBits(val.low, val.high, typeof unsigned === 'boolean' ? unsigned : val.unsigned); } /** * Converts the specified value to a Long using the appropriate from* function for its type. 
* @function * @param {!Long|number|string|!{low: number, high: number, unsigned: boolean}} val Value * @param {boolean=} unsigned Whether unsigned or not, defaults to signed * @returns {!Long} */ Long.fromValue = fromValue; // NOTE: the compiler should inline these constant values below and then remove these variables, so there should be // no runtime penalty for these. /** * @type {number} * @const * @inner */ var TWO_PWR_16_DBL = 1 << 16; /** * @type {number} * @const * @inner */ var TWO_PWR_24_DBL = 1 << 24; /** * @type {number} * @const * @inner */ var TWO_PWR_32_DBL = TWO_PWR_16_DBL * TWO_PWR_16_DBL; /** * @type {number} * @const * @inner */ var TWO_PWR_64_DBL = TWO_PWR_32_DBL * TWO_PWR_32_DBL; /** * @type {number} * @const * @inner */ var TWO_PWR_63_DBL = TWO_PWR_64_DBL / 2; /** * @type {!Long} * @const * @inner */ var TWO_PWR_24 = fromInt(TWO_PWR_24_DBL); /** * @type {!Long} * @inner */ var ZERO = fromInt(0); /** * Signed zero. * @type {!Long} */ Long.ZERO = ZERO; /** * @type {!Long} * @inner */ var UZERO = fromInt(0, true); /** * Unsigned zero. * @type {!Long} */ Long.UZERO = UZERO; /** * @type {!Long} * @inner */ var ONE = fromInt(1); /** * Signed one. * @type {!Long} */ Long.ONE = ONE; /** * @type {!Long} * @inner */ var UONE = fromInt(1, true); /** * Unsigned one. * @type {!Long} */ Long.UONE = UONE; /** * @type {!Long} * @inner */ var NEG_ONE = fromInt(-1); /** * Signed negative one. * @type {!Long} */ Long.NEG_ONE = NEG_ONE; /** * @type {!Long} * @inner */ var MAX_VALUE = fromBits(0xFFFFFFFF|0, 0x7FFFFFFF|0, false); /** * Maximum signed value. * @type {!Long} */ Long.MAX_VALUE = MAX_VALUE; /** * @type {!Long} * @inner */ var MAX_UNSIGNED_VALUE = fromBits(0xFFFFFFFF|0, 0xFFFFFFFF|0, true); /** * Maximum unsigned value. * @type {!Long} */ Long.MAX_UNSIGNED_VALUE = MAX_UNSIGNED_VALUE; /** * @type {!Long} * @inner */ var MIN_VALUE = fromBits(0, 0x80000000|0, false); /** * Minimum signed value. 
* @type {!Long} */ Long.MIN_VALUE = MIN_VALUE; /** * @alias Long.prototype * @inner */ var LongPrototype = Long.prototype; /** * Converts the Long to a 32 bit integer, assuming it is a 32 bit integer. * @returns {number} */ LongPrototype.toInt = function toInt() { return this.unsigned ? this.low >>> 0 : this.low; }; /** * Converts the Long to a the nearest floating-point representation of this value (double, 53 bit mantissa). * @returns {number} */ LongPrototype.toNumber = function toNumber() { if (this.unsigned) return ((this.high >>> 0) * TWO_PWR_32_DBL) + (this.low >>> 0); return this.high * TWO_PWR_32_DBL + (this.low >>> 0); }; /** * Converts the Long to a string written in the specified radix. * @param {number=} radix Radix (2-36), defaults to 10 * @returns {string} * @override * @throws {RangeError} If `radix` is out of range */ LongPrototype.toString = function toString(radix) { radix = radix || 10; if (radix < 2 || 36 < radix) throw RangeError('radix'); if (this.isZero()) return '0'; if (this.isNegative()) { // Unsigned Longs are never negative if (this.eq(MIN_VALUE)) { // We need to change the Long value before it can be negated, so we remove // the bottom-most digit in this base and then recurse to do the rest. var radixLong = fromNumber(radix), div = this.div(radixLong), rem1 = div.mul(radixLong).sub(this); return div.toString(radix) + rem1.toInt().toString(radix); } else return '-' + this.neg().toString(radix); } // Do several (6) digits each time through the loop, so as to // minimize the calls to the very expensive emulated div. 
var radixToPower = fromNumber(pow_dbl(radix, 6), this.unsigned), rem = this; var result = ''; while (true) { var remDiv = rem.div(radixToPower), intval = rem.sub(remDiv.mul(radixToPower)).toInt() >>> 0, digits = intval.toString(radix); rem = remDiv; if (rem.isZero()) return digits + result; else { while (digits.length < 6) digits = '0' + digits; result = '' + digits + result; } } }; /** * Gets the high 32 bits as a signed integer. * @returns {number} Signed high bits */ LongPrototype.getHighBits = function getHighBits() { return this.high; }; /** * Gets the high 32 bits as an unsigned integer. * @returns {number} Unsigned high bits */ LongPrototype.getHighBitsUnsigned = function getHighBitsUnsigned() { return this.high >>> 0; }; /** * Gets the low 32 bits as a signed integer. * @returns {number} Signed low bits */ LongPrototype.getLowBits = function getLowBits() { return this.low; }; /** * Gets the low 32 bits as an unsigned integer. * @returns {number} Unsigned low bits */ LongPrototype.getLowBitsUnsigned = function getLowBitsUnsigned() { return this.low >>> 0; }; /** * Gets the number of bits needed to represent the absolute value of this Long. * @returns {number} */ LongPrototype.getNumBitsAbs = function getNumBitsAbs() { if (this.isNegative()) // Unsigned Longs are never negative return this.eq(MIN_VALUE) ? 64 : this.neg().getNumBitsAbs(); var val = this.high != 0 ? this.high : this.low; for (var bit = 31; bit > 0; bit--) if ((val & (1 << bit)) != 0) break; return this.high != 0 ? bit + 33 : bit + 1; }; /** * Tests if this Long's value equals zero. * @returns {boolean} */ LongPrototype.isZero = function isZero() { return this.high === 0 && this.low === 0; }; /** * Tests if this Long's value equals zero. This is an alias of {@link Long#isZero}. * @returns {boolean} */ LongPrototype.eqz = LongPrototype.isZero; /** * Tests if this Long's value is negative. 
* @returns {boolean} */ LongPrototype.isNegative = function isNegative() { return !this.unsigned && this.high < 0; }; /** * Tests if this Long's value is positive. * @returns {boolean} */ LongPrototype.isPositive = function isPositive() { return this.unsigned || this.high >= 0; }; /** * Tests if this Long's value is odd. * @returns {boolean} */ LongPrototype.isOdd = function isOdd() { return (this.low & 1) === 1; }; /** * Tests if this Long's value is even. * @returns {boolean} */ LongPrototype.isEven = function isEven() { return (this.low & 1) === 0; }; /** * Tests if this Long's value equals the specified's. * @param {!Long|number|string} other Other value * @returns {boolean} */ LongPrototype.equals = function equals(other) { if (!isLong(other)) other = fromValue(other); if (this.unsigned !== other.unsigned && (this.high >>> 31) === 1 && (other.high >>> 31) === 1) return false; return this.high === other.high && this.low === other.low; }; /** * Tests if this Long's value equals the specified's. This is an alias of {@link Long#equals}. * @function * @param {!Long|number|string} other Other value * @returns {boolean} */ LongPrototype.eq = LongPrototype.equals; /** * Tests if this Long's value differs from the specified's. * @param {!Long|number|string} other Other value * @returns {boolean} */ LongPrototype.notEquals = function notEquals(other) { return !this.eq(/* validates */ other); }; /** * Tests if this Long's value differs from the specified's. This is an alias of {@link Long#notEquals}. * @function * @param {!Long|number|string} other Other value * @returns {boolean} */ LongPrototype.neq = LongPrototype.notEquals; /** * Tests if this Long's value differs from the specified's. This is an alias of {@link Long#notEquals}. * @function * @param {!Long|number|string} other Other value * @returns {boolean} */ LongPrototype.ne = LongPrototype.notEquals; /** * Tests if this Long's value is less than the specified's. 
* @param {!Long|number|string} other Other value * @returns {boolean} */ LongPrototype.lessThan = function lessThan(other) { return this.comp(/* validates */ other) < 0; }; /** * Tests if this Long's value is less than the specified's. This is an alias of {@link Long#lessThan}. * @function * @param {!Long|number|string} other Other value * @returns {boolean} */ LongPrototype.lt = LongPrototype.lessThan; /** * Tests if this Long's value is less than or equal the specified's. * @param {!Long|number|string} other Other value * @returns {boolean} */ LongPrototype.lessThanOrEqual = function lessThanOrEqual(other) { return this.comp(/* validates */ other) <= 0; }; /** * Tests if this Long's value is less than or equal the specified's. This is an alias of {@link Long#lessThanOrEqual}. * @function * @param {!Long|number|string} other Other value * @returns {boolean} */ LongPrototype.lte = LongPrototype.lessThanOrEqual; /** * Tests if this Long's value is less than or equal the specified's. This is an alias of {@link Long#lessThanOrEqual}. * @function * @param {!Long|number|string} other Other value * @returns {boolean} */ LongPrototype.le = LongPrototype.lessThanOrEqual; /** * Tests if this Long's value is greater than the specified's. * @param {!Long|number|string} other Other value * @returns {boolean} */ LongPrototype.greaterThan = function greaterThan(other) { return this.comp(/* validates */ other) > 0; }; /** * Tests if this Long's value is greater than the specified's. This is an alias of {@link Long#greaterThan}. * @function * @param {!Long|number|string} other Other value * @returns {boolean} */ LongPrototype.gt = LongPrototype.greaterThan; /** * Tests if this Long's value is greater than or equal the specified's. 
* @param {!Long|number|string} other Other value * @returns {boolean} */ LongPrototype.greaterThanOrEqual = function greaterThanOrEqual(other) { return this.comp(/* validates */ other) >= 0; }; /** * Tests if this Long's value is greater than or equal the specified's. This is an alias of {@link Long#greaterThanOrEqual}. * @function * @param {!Long|number|string} other Other value * @returns {boolean} */ LongPrototype.gte = LongPrototype.greaterThanOrEqual; /** * Tests if this Long's value is greater than or equal the specified's. This is an alias of {@link Long#greaterThanOrEqual}. * @function * @param {!Long|number|string} other Other value * @returns {boolean} */ LongPrototype.ge = LongPrototype.greaterThanOrEqual; /** * Compares this Long's value with the specified's. * @param {!Long|number|string} other Other value * @returns {number} 0 if they are the same, 1 if the this is greater and -1 * if the given one is greater */ LongPrototype.compare = function compare(other) { if (!isLong(other)) other = fromValue(other); if (this.eq(other)) return 0; var thisNeg = this.isNegative(), otherNeg = other.isNegative(); if (thisNeg && !otherNeg) return -1; if (!thisNeg && otherNeg) return 1; // At this point the sign bits are the same if (!this.unsigned) return this.sub(other).isNegative() ? -1 : 1; // Both are positive if at least one is unsigned return (other.high >>> 0) > (this.high >>> 0) || (other.high === this.high && (other.low >>> 0) > (this.low >>> 0)) ? -1 : 1; }; /** * Compares this Long's value with the specified's. This is an alias of {@link Long#compare}. * @function * @param {!Long|number|string} other Other value * @returns {number} 0 if they are the same, 1 if the this is greater and -1 * if the given one is greater */ LongPrototype.comp = LongPrototype.compare; /** * Negates this Long's value. 
* @returns {!Long} Negated Long
 */
LongPrototype.negate = function negate() {
  // -MIN_VALUE is not representable; in two's complement it wraps back to
  // MIN_VALUE, so return it directly.
  if (!this.unsigned && this.eq(MIN_VALUE)) return MIN_VALUE;
  // Two's-complement negation: complement, then add one. this.not() is
  // defined outside this section; presumably the bitwise complement.
  return this.not().add(ONE);
};
/**
 * Negates this Long's value. This is an alias of {@link Long#negate}.
 * @function
 * @returns {!Long} Negated Long
 */
LongPrototype.neg = LongPrototype.negate;
/**
 * Returns the sum of this and the specified Long. The sum wraps modulo 2^64
 * and carries this Long's signedness.
 * @param {!Long|number|string} addend Addend
 * @returns {!Long} Sum
 */
LongPrototype.add = function add(addend) {
  if (!isLong(addend)) addend = fromValue(addend);
  // Divide each number into 4 chunks of 16 bits, and then sum the chunks.
  var a48 = this.high >>> 16;
  var a32 = this.high & 0xFFFF;
  var a16 = this.low >>> 16;
  var a00 = this.low & 0xFFFF;
  var b48 = addend.high >>> 16;
  var b32 = addend.high & 0xFFFF;
  var b16 = addend.low >>> 16;
  var b00 = addend.low & 0xFFFF;
  var c48 = 0, c32 = 0, c16 = 0, c00 = 0;
  // For each column, (c >>> 16) is the carry into the next column and
  // (c &= 0xFFFF) keeps the column's own 16 bits.
  c00 += a00 + b00;
  c16 += c00 >>> 16;
  c00 &= 0xFFFF;
  c16 += a16 + b16;
  c32 += c16 >>> 16;
  c16 &= 0xFFFF;
  c32 += a32 + b32;
  c48 += c32 >>> 16;
  c32 &= 0xFFFF;
  c48 += a48 + b48;
  // Dropping the carry out of the top column makes the sum wrap modulo 2^64.
  c48 &= 0xFFFF;
  return fromBits((c16 << 16) | c00, (c48 << 16) | c32, this.unsigned);
};
/**
 * Returns the difference of this and the specified Long, computed as
 * this + (-subtrahend).
 * @param {!Long|number|string} subtrahend Subtrahend
 * @returns {!Long} Difference
 */
LongPrototype.subtract = function subtract(subtrahend) {
  if (!isLong(subtrahend)) subtrahend = fromValue(subtrahend);
  return this.add(subtrahend.neg());
};
/**
 * Returns the difference of this and the specified Long. This is an alias of {@link Long#subtract}.
 * @function
 * @param {!Long|number|string} subtrahend Subtrahend
 * @returns {!Long} Difference
 */
LongPrototype.sub = LongPrototype.subtract;
/**
 * Returns the product of this and the specified Long.
* @param {!Long|number|string} multiplier Multiplier * @returns {!Long} Product */ LongPrototype.multiply = function multiply(multiplier) { if (this.isZero()) return ZERO; if (!isLong(multiplier)) multiplier = fromValue(multiplier); // use wasm support if present if (wasm) { var low = wasm.mul(this.low, this.high, multiplier.low, multiplier.high); return fromBits(low, wasm.get_high(), this.unsigned); } if (multiplier.isZero()) return ZERO; if (this.eq(MIN_VALUE)) return multiplier.isOdd() ? MIN_VALUE : ZERO; if (multiplier.eq(MIN_VALUE)) return this.isOdd() ? MIN_VALUE : ZERO; if (this.isNegative()) { if (multiplier.isNegative()) return this.neg().mul(multiplier.neg()); else return this.neg().mul(multiplier).neg(); } else if (multiplier.isNegative()) return this.mul(multiplier.neg()).neg(); // If both longs are small, use float multiplication if (this.lt(TWO_PWR_24) && multiplier.lt(TWO_PWR_24)) return fromNumber(this.toNumber() * multiplier.toNumber(), this.unsigned); // Divide each long into 4 chunks of 16 bits, and then add up 4x4 products. // We can skip products that would overflow. var a48 = this.high >>> 16; var a32 = this.high & 0xFFFF; var a16 = this.low >>> 16; var a00 = this.low & 0xFFFF; var b48 = multiplier.high >>> 16; var b32 = multiplier.high & 0xFFFF; var b16 = multiplier.low >>> 16; var b00 = multiplier.low & 0xFFFF; var c48 = 0, c32 = 0, c16 = 0, c00 = 0; c00 += a00 * b00; c16 += c00 >>> 16; c00 &= 0xFFFF; c16 += a16 * b00; c32 += c16 >>> 16; c16 &= 0xFFFF; c16 += a00 * b16; c32 += c16 >>> 16; c16 &= 0xFFFF; c32 += a32 * b00; c48 += c32 >>> 16; c32 &= 0xFFFF; c32 += a16 * b16; c48 += c32 >>> 16; c32 &= 0xFFFF; c32 += a00 * b32; c48 += c32 >>> 16; c32 &= 0xFFFF; c48 += a48 * b00 + a32 * b16 + a16 * b32 + a00 * b48; c48 &= 0xFFFF; return fromBits((c16 << 16) | c00, (c48 << 16) | c32, this.unsigned); }; /** * Returns the product of this and the specified Long. This is an alias of {@link Long#multiply}. 
* @function * @param {!Long|number|string} multiplier Multiplier * @returns {!Long} Product */ LongPrototype.mul = LongPrototype.multiply; /** * Returns this Long divided by the specified. The result is signed if this Long is signed or * unsigned if this Long is unsigned. * @param {!Long|number|string} divisor Divisor * @returns {!Long} Quotient */ LongPrototype.divide = function divide(divisor) { if (!isLong(divisor)) divisor = fromValue(divisor); if (divisor.isZero()) throw Error('division by zero'); // use wasm support if present if (wasm) { // guard against signed division overflow: the largest // negative number / -1 would be 1 larger than the largest // positive number, due to two's complement. if (!this.unsigned && this.high === -0x80000000 && divisor.low === -1 && divisor.high === -1) { // be consistent with non-wasm code path return this; } var low = (this.unsigned ? wasm.div_u : wasm.div_s)( this.low, this.high, divisor.low, divisor.high ); return fromBits(low, wasm.get_high(), this.unsigned); } if (this.isZero()) return this.unsigned ? UZERO : ZERO; var approx, rem, res; if (!this.unsigned) { // This section is only relevant for signed longs and is derived from the // closure library as a whole. if (this.eq(MIN_VALUE)) { if (divisor.eq(ONE) || divisor.eq(NEG_ONE)) return MIN_VALUE; // recall that -MIN_VALUE == MIN_VALUE else if (divisor.eq(MIN_VALUE)) return ONE; else { // At this point, we have |other| >= 2, so |this/other| < |MIN_VALUE|. var halfThis = this.shr(1); approx = halfThis.div(divisor).shl(1); if (approx.eq(ZERO)) { return divisor.isNegative() ? ONE : NEG_ONE; } else { rem = this.sub(divisor.mul(approx)); res = approx.add(rem.div(divisor)); return res; } } } else if (divisor.eq(MIN_VALUE)) return this.unsigned ? 
UZERO : ZERO; if (this.isNegative()) { if (divisor.isNegative()) return this.neg().div(divisor.neg()); return this.neg().div(divisor).neg(); } else if (divisor.isNegative()) return this.div(divisor.neg()).neg(); res = ZERO; } else { // The algorithm below has not been made for unsigned longs. It's therefore // required to take special care of the MSB prior to running it. if (!divisor.unsigned) divisor = divisor.toUnsigned(); if (divisor.gt(this)) return UZERO; if (divisor.gt(this.shru(1))) // 15 >>> 1 = 7 ; with divisor = 8 ; true return UONE; res = UZERO; } // Repeat the following until the remainder is less than other: find a // floating-point that approximates remainder / other *from below*, add this // into the result, and subtract it from the remainder. It is critical that // the approximate value is less than or equal to the real value so that the // remainder never becomes negative. rem = this; while (rem.gte(divisor)) { // Approximate the result of division. This may be a little greater or // smaller than the actual value. approx = Math.max(1, Math.floor(rem.toNumber() / divisor.toNumber())); // We will tweak the approximate result by changing it in the 48-th digit or // the smallest non-fractional digit, whichever is larger. var log2 = Math.ceil(Math.log(approx) / Math.LN2), delta = (log2 <= 48) ? 1 : pow_dbl(2, log2 - 48), // Decrease the approximation until it is smaller than the remainder. Note // that if it is too large, the product overflows and is negative. approxRes = fromNumber(approx), approxRem = approxRes.mul(divisor); while (approxRem.isNegative() || approxRem.gt(rem)) { approx -= delta; approxRes = fromNumber(approx, this.unsigned); approxRem = approxRes.mul(divisor); } // We know the answer can't be zero... and actually, zero would cause // infinite recursion since we would make no progress. 
if (approxRes.isZero()) approxRes = ONE; res = res.add(approxRes); rem = rem.sub(approxRem); } return res; }; /** * Returns this Long divided by the specified. This is an alias of {@link Long#divide}. * @function * @param {!Long|number|string} divisor Divisor * @returns {!Long} Quotient */ LongPrototype.div = LongPrototype.divide; /** * Returns this Long modulo the specified. * @param {!Long|number|string} divisor Divisor * @returns {!Long} Remainder */ LongPrototype.modulo = function modulo(divisor) { if (!isLong(divisor)) divisor = fromValue(divisor); // use wasm support if present if (wasm) { var low = (this.unsigned ? wasm.rem_u : wasm.rem_s)( this.low, this.high, divisor.low, divisor.high ); return fromBits(low, wasm.get_high(), this.unsigned); } return this.sub(this.div(divisor).mul(divisor)); }; /** * Returns this Long modulo the specified. This is an alias of {@link Long#modulo}. * @function * @param {!Long|number|string} divisor Divisor * @returns {!Long} Remainder */ LongPrototype.mod = LongPrototype.modulo; /** * Returns this Long modulo the specified. This is an alias of {@link Long#modulo}. * @function * @param {!Long|number|string} divisor Divisor * @returns {!Long} Remainder */ LongPrototype.rem = LongPrototype.modulo; /** * Returns the bitwise NOT of this Long. * @returns {!Long} */ LongPrototype.not = function not() { return fromBits(~this.low, ~this.high, this.unsigned); }; /** * Returns the bitwise AND of this Long and the specified. * @param {!Long|number|string} other Other Long * @returns {!Long} */ LongPrototype.and = function and(other) { if (!isLong(other)) other = fromValue(other); return fromBits(this.low & other.low, this.high & other.high, this.unsigned); }; /** * Returns the bitwise OR of this Long and the specified. 
* @param {!Long|number|string} other Other Long * @returns {!Long} */ LongPrototype.or = function or(other) { if (!isLong(other)) other = fromValue(other); return fromBits(this.low | other.low, this.high | other.high, this.unsigned); }; /** * Returns the bitwise XOR of this Long and the given one. * @param {!Long|number|string} other Other Long * @returns {!Long} */ LongPrototype.xor = function xor(other) { if (!isLong(other)) other = fromValue(other); return fromBits(this.low ^ other.low, this.high ^ other.high, this.unsigned); }; /** * Returns this Long with bits shifted to the left by the given amount. * @param {number|!Long} numBits Number of bits * @returns {!Long} Shifted Long */ LongPrototype.shiftLeft = function shiftLeft(numBits) { if (isLong(numBits)) numBits = numBits.toInt(); if ((numBits &= 63) === 0) return this; else if (numBits < 32) return fromBits(this.low << numBits, (this.high << numBits) | (this.low >>> (32 - numBits)), this.unsigned); else return fromBits(0, this.low << (numBits - 32), this.unsigned); }; /** * Returns this Long with bits shifted to the left by the given amount. This is an alias of {@link Long#shiftLeft}. * @function * @param {number|!Long} numBits Number of bits * @returns {!Long} Shifted Long */ LongPrototype.shl = LongPrototype.shiftLeft; /** * Returns this Long with bits arithmetically shifted to the right by the given amount. * @param {number|!Long} numBits Number of bits * @returns {!Long} Shifted Long */ LongPrototype.shiftRight = function shiftRight(numBits) { if (isLong(numBits)) numBits = numBits.toInt(); if ((numBits &= 63) === 0) return this; else if (numBits < 32) return fromBits((this.low >>> numBits) | (this.high << (32 - numBits)), this.high >> numBits, this.unsigned); else return fromBits(this.high >> (numBits - 32), this.high >= 0 ? 0 : -1, this.unsigned); }; /** * Returns this Long with bits arithmetically shifted to the right by the given amount. This is an alias of {@link Long#shiftRight}. 
* @function * @param {number|!Long} numBits Number of bits * @returns {!Long} Shifted Long */ LongPrototype.shr = LongPrototype.shiftRight; /** * Returns this Long with bits logically shifted to the right by the given amount. * @param {number|!Long} numBits Number of bits * @returns {!Long} Shifted Long */ LongPrototype.shiftRightUnsigned = function shiftRightUnsigned(numBits) { if (isLong(numBits)) numBits = numBits.toInt(); numBits &= 63; if (numBits === 0) return this; else { var high = this.high; if (numBits < 32) { var low = this.low; return fromBits((low >>> numBits) | (high << (32 - numBits)), high >>> numBits, this.unsigned); } else if (numBits === 32) return fromBits(high, 0, this.unsigned); else return fromBits(high >>> (numBits - 32), 0, this.unsigned); } }; /** * Returns this Long with bits logically shifted to the right by the given amount. This is an alias of {@link Long#shiftRightUnsigned}. * @function * @param {number|!Long} numBits Number of bits * @returns {!Long} Shifted Long */ LongPrototype.shru = LongPrototype.shiftRightUnsigned; /** * Returns this Long with bits logically shifted to the right by the given amount. This is an alias of {@link Long#shiftRightUnsigned}. * @function * @param {number|!Long} numBits Number of bits * @returns {!Long} Shifted Long */ LongPrototype.shr_u = LongPrototype.shiftRightUnsigned; /** * Converts this Long to signed. * @returns {!Long} Signed long */ LongPrototype.toSigned = function toSigned() { if (!this.unsigned) return this; return fromBits(this.low, this.high, false); }; /** * Converts this Long to unsigned. * @returns {!Long} Unsigned long */ LongPrototype.toUnsigned = function toUnsigned() { if (this.unsigned) return this; return fromBits(this.low, this.high, true); }; /** * Converts this Long to its byte representation. * @param {boolean=} le Whether little or big endian, defaults to big endian * @returns {!Array.} Byte representation */ LongPrototype.toBytes = function toBytes(le) { return le ? 
this.toBytesLE() : this.toBytesBE(); }; /** * Converts this Long to its little endian byte representation. * @returns {!Array.} Little endian byte representation */ LongPrototype.toBytesLE = function toBytesLE() { var hi = this.high, lo = this.low; return [ lo & 0xff, lo >>> 8 & 0xff, lo >>> 16 & 0xff, lo >>> 24 , hi & 0xff, hi >>> 8 & 0xff, hi >>> 16 & 0xff, hi >>> 24 ]; }; /** * Converts this Long to its big endian byte representation. * @returns {!Array.} Big endian byte representation */ LongPrototype.toBytesBE = function toBytesBE() { var hi = this.high, lo = this.low; return [ hi >>> 24 , hi >>> 16 & 0xff, hi >>> 8 & 0xff, hi & 0xff, lo >>> 24 , lo >>> 16 & 0xff, lo >>> 8 & 0xff, lo & 0xff ]; }; /** * Creates a Long from its byte representation. * @param {!Array.} bytes Byte representation * @param {boolean=} unsigned Whether unsigned or not, defaults to signed * @param {boolean=} le Whether little or big endian, defaults to big endian * @returns {Long} The corresponding Long value */ Long.fromBytes = function fromBytes(bytes, unsigned, le) { return le ? Long.fromBytesLE(bytes, unsigned) : Long.fromBytesBE(bytes, unsigned); }; /** * Creates a Long from its little endian byte representation. * @param {!Array.} bytes Little endian byte representation * @param {boolean=} unsigned Whether unsigned or not, defaults to signed * @returns {Long} The corresponding Long value */ Long.fromBytesLE = function fromBytesLE(bytes, unsigned) { return new Long( bytes[0] | bytes[1] << 8 | bytes[2] << 16 | bytes[3] << 24, bytes[4] | bytes[5] << 8 | bytes[6] << 16 | bytes[7] << 24, unsigned ); }; /** * Creates a Long from its big endian byte representation. 
* @param {!Array.<number>} bytes Big endian byte representation
 * @param {boolean=} unsigned Whether unsigned or not, defaults to signed
 * @returns {Long} The corresponding Long value
 *
 * bytes[0..3] become the high word and bytes[4..7] the low word, with
 * bytes[0] the most significant byte.
 */
Long.fromBytesBE = function fromBytesBE(bytes, unsigned) {
    return new Long(
        bytes[4] << 24 |
        bytes[5] << 16 |
        bytes[6] <<  8 |
        bytes[7],
        bytes[0] << 24 |
        bytes[1] << 16 |
        bytes[2] <<  8 |
        bytes[3],
        unsigned
    );
};

// Namespace-shaped wrapper around long_1 ('default' and __moduleExports both
// point at it); used below as the fallback when long_1 is falsy.
var LongExports = {
    __proto__: null,
    'default': long_1,
    __moduleExports: long_1
};

/**
 * @license
 * Copyright 2021 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
// The hashing helpers below appear to be a port of FarmHash's 64-bit
// fingerprint (farmhashna Hash64); all arithmetic is on Long values so that it
// wraps modulo 2^64 like the C++ uint64_t original. NOTE(review): confirm
// against the upstream source before changing any constant or statement order,
// since the output must stay bit-identical.

// tslint:disable-next-line
var Long$1 = 
// tslint:disable-next-line
long_1 || LongExports;

/**
 * Parses a hexadecimal string into a Long.
 * NOTE(review): the `true, 16` arguments presumably mean (unsigned, radix 16)
 * per long.js `fromString(str, unsigned, radix)` — confirm.
 */
function hexToLong(hex) {
    return Long$1.fromString(hex, true, 16);
}

// Some primes between 2^63 and 2^64 for various uses.
// Hex 0xc3a5c85c97cb3127
var k0 = hexToLong('c3a5c85c97cb3127');
// Hex 0xb492b66fbe98f273
var k1 = hexToLong('b492b66fbe98f273');
// Hex 0x9ae16a3b2f90404f
var k2 = hexToLong('9ae16a3b2f90404f');

/** Returns val ^ (val >>> 47): folds the top 17 bits down into the low bits. */
function shiftMix(val) {
    return val.xor(val.shru(47));
}

/**
 * Reads `numBytes` bytes of `s` starting at `offset` as a little-endian
 * unsigned Long. When fewer than 8 bytes are read, the missing high bytes are
 * `undefined`, which the bitwise operators in fromBytesLE treat as 0.
 * NOTE(review): `s` is presumably a Uint8Array (it needs `.slice`) — confirm.
 */
function fetch$1(s, offset, numBytes) {
    var bytes = s.slice(offset, offset + numBytes);
    return Long$1.fromBytes(Array.from(bytes), true, true);
}

/** Reads 8 bytes at `offset` as a little-endian unsigned Long. */
function fetch64(s, offset) {
    return fetch$1(s, offset, 8);
}

/** Reads 4 bytes at `offset` as a little-endian unsigned Long (high word 0). */
function fetch32(s, offset) {
    return fetch$1(s, offset, 4);
}

/** Rotates the 64 bits of `val` right by `shift` bits. */
function rotate64(val, shift) {
    // Avoid shifting by 64: doing so yields an undefined result.
    // (That is a C/C++ concern; Long#shl masks its count with 63, but the
    // special case is kept to mirror the original algorithm.)
    return shift === 0 ? val : val.shru(shift).or(val.shl(64 - shift));
}

/**
 * Mixes two 64-bit values into one using multiplier `mul`.
 * The default multiplier is parsed from hex on every call that omits it; the
 * callers in this file always pass `mul` explicitly.
 */
function hashLen16(u, v, mul) {
    if (mul === void 0) { mul = hexToLong('9ddfea08eb382d69'); }
    // Murmur-inspired hashing.
    var a = u.xor(v).mul(mul);
    a = a.xor(a.shru(47));
    var b = v.xor(a).mul(mul);
    b = b.xor(b.shru(47));
    b = b.mul(mul);
    return b;
}

// Return a 16-byte hash for 48 bytes.  Quick and dirty.
// Callers do best to use "random-looking" values for a and b.
// The 16-byte result is returned as a two-element array of Longs.
function weakHashLen32WithSeeds(w, x, y, z, a, b) {
    a = a.add(w);
    b = rotate64(b.add(a).add(z), 21);
    var c = a;
    a = a.add(x);
    a = a.add(y);
    b = b.add(rotate64(a, 44));
    return [a.add(z), b.add(c)];
}

/** weakHashLen32WithSeeds over the four 8-byte words of s[offset .. offset + 32). */
function weakHashLen32WithSeedsStr(s, offset, a, b) {
    return weakHashLen32WithSeeds(fetch64(s, offset), fetch64(s, offset + 8), fetch64(s, offset + 16), fetch64(s, offset + 24), a, b);
}

/**
 * Hash for inputs of 0 to 16 bytes. An empty input (len 0) hashes to k2.
 * NOTE(review): in the 1-3 byte branch, when `y` or `z` is 0 the signedness of
 * the result depends on how Long#multiply treats a zero multiplier; the hash
 * is meant to be an unsigned 64-bit value — confirm.
 */
function hashLen0to16(s, len) {
    if (len === void 0) { len = s.length; }
    if (len >= 8) {
        var mul = k2.add(len * 2);
        var a = fetch64(s, 0).add(k2);
        var b = fetch64(s, len - 8);
        var c = rotate64(b, 37).mul(mul).add(a);
        var d = rotate64(a, 25).add(b).mul(mul);
        return hashLen16(c, d, mul);
    }
    if (len >= 4) {
        var mul = k2.add(len * 2);
        var a = fetch32(s, 0);
        return hashLen16(a.shl(3).add(len), fetch32(s, len - 4), mul);
    }
    if (len > 0) {
        // 1-3 bytes: combine the first, middle and last byte as plain numbers.
        var a = s[0];
        var b = s[len >> 1];
        var c = s[len - 1];
        var y = a + (b << 8);
        var z = len + (c << 2);
        return shiftMix(k2.mul(y).xor(k0.mul(z))).mul(k2);
    }
    return k2;
}

/** Hash for inputs of 17 to 32 bytes (reads the first 16 and last 16 bytes). */
function hashLen17to32(s, len) {
    if (len === void 0) { len = s.length; }
    var mul = k2.add(len * 2);
    var a = fetch64(s, 0).mul(k1);
    var b = fetch64(s, 8);
    var c = fetch64(s, len - 8).mul(mul);
    var d = fetch64(s, len - 16).mul(k2);
    return hashLen16(rotate64(a.add(b), 43).add(rotate64(c, 30)).add(d), a.add(rotate64(b.add(k2), 18)).add(c), mul);
}

/** Hash for inputs of 33 to 64 bytes (reads the first 32 and last 32 bytes). */
function hashLen33to64(s, len) {
    if (len === void 0) { len = s.length; }
    var mul = k2.add(len * 2);
    var a = fetch64(s, 0).mul(k2);
    var b = fetch64(s, 8);
    var c = fetch64(s, len - 8).mul(mul);
    var d = fetch64(s, len - 16).mul(k2);
    var y = rotate64(a.add(b), 43).add(rotate64(c, 30)).add(d);
    var z = hashLen16(y, a.add(rotate64(b.add(k2), 18)).add(c), mul);
    var e = fetch64(s, 16).mul(mul);
    var f = fetch64(s, 24);
    var g = y.add(fetch64(s, len - 32)).mul(mul);
    var h = z.add(fetch64(s, len - 24)).mul(mul);
    return hashLen16(rotate64(e.add(f), 43).add(rotate64(g, 30)).add(h), e.add(rotate64(f.add(a), 18)).add(g), mul);
}

/**
 * Computes a 64-bit fingerprint of the first `len` bytes of `s`
 * (default: all of `s`).
 *
 * Inputs of at most 64 bytes dispatch to the length-specialized helpers above.
 * Longer inputs are consumed in 64-byte chunks, and the final (possibly
 * overlapping) 64 bytes are mixed in once more before finalization.
 */
function fingerPrint64(s, len) {
    var _a, _b;
    if (len === void 0) { len = s.length; }
    var seed = Long$1.fromNumber(81, true);
    if (len <= 32) {
        if (len <= 16) {
            return hashLen0to16(s, len);
        }
        else {
            return hashLen17to32(s, len);
        }
    }
    else if (len <= 64) {
        return hashLen33to64(s, len);
    }
    // For strings over 64 bytes we loop.  Internal state consists of
    // 56 bytes: v, w, x, y, and z.
    var x = seed;
    var y = seed.mul(k1).add(113);
    var z = shiftMix(y.mul(k2).add(113)).mul(k2);
    var v = [Long$1.UZERO, Long$1.UZERO];
    var w = [Long$1.UZERO, Long$1.UZERO];
    x = x.mul(k2).add(fetch64(s, 0));
    var offset = 0;
    // Set end so that after the loop we have 1 to 64 bytes left to process.
    var end = ((len - 1) >> 6) * 64;
    // last64 === len - 64: start of the final 64-byte window.
    var last64 = end + ((len - 1) & 63) - 63;
    do {
        x = rotate64(x.add(y).add(v[0]).add(fetch64(s, offset + 8)), 37).mul(k1);
        y = rotate64(y.add(v[1]).add(fetch64(s, offset + 48)), 42).mul(k1);
        x = x.xor(w[1]);
        y = y.add(v[0]).add(fetch64(s, offset + 40));
        z = rotate64(z.add(w[0]), 33).mul(k1);
        v = weakHashLen32WithSeedsStr(s, offset, v[1].mul(k1), x.add(w[0]));
        w = weakHashLen32WithSeedsStr(s, offset + 32, z.add(w[1]), y.add(fetch64(s, offset + 16)));
        // Swap x and z (std::swap in the C++ original).
        _a = [x, z], z = _a[0], x = _a[1];
        offset += 64;
    } while (offset !== end);
    var mul = k1.add(z.and(0xff).shl(1));
    // Point to the last 64 bytes of input.
    offset = last64;
    w[0] = w[0].add((len - 1) & 63);
    v[0] = v[0].add(w[0]);
    w[0] = w[0].add(v[0]);
    x = rotate64(x.add(y).add(v[0]).add(fetch64(s, offset + 8)), 37).mul(mul);
    y = rotate64(y.add(v[1]).add(fetch64(s, offset + 48)), 42).mul(mul);
    x = x.xor(w[1].mul(9));
    y = y.add(v[0].mul(9).add(fetch64(s, offset + 40)));
    z = rotate64(z.add(w[0]), 33).mul(mul);
    v = weakHashLen32WithSeedsStr(s, offset, v[1].mul(mul), x.add(w[0]));
    w = weakHashLen32WithSeedsStr(s, offset + 32, z.add(w[1]), y.add(fetch64(s, offset + 16)));
    // Swap x and z again.
    _b = [x, z], z = _b[0], x = _b[1];
    return hashLen16(hashLen16(v[0], w[0], mul).add(shiftMix(y).mul(k0)).add(z), hashLen16(v[1], w[1], mul).add(x), mul);
}

/**
 * @license
 * Copyright 2017 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
/**
 * Create typed array for scalar value. Used for storing in `DataStorage`.
*/ function createScalarValue(value, dtype) { if (dtype === 'string') { return encodeString(value); } return toTypedArray([value], dtype); } function noConversionNeeded(a, dtype) { return (a instanceof Float32Array && dtype === 'float32') || (a instanceof Int32Array && dtype === 'int32') || (a instanceof Uint8Array && dtype === 'bool'); } function toTypedArray(a, dtype) { if (dtype === 'string') { throw new Error('Cannot convert a string[] to a TypedArray'); } if (Array.isArray(a)) { a = flatten(a); } if (env().getBool('DEBUG')) { checkConversionForErrors(a, dtype); } if (noConversionNeeded(a, dtype)) { return a; } if (dtype == null || dtype === 'float32' || dtype === 'complex64') { return new Float32Array(a); } else if (dtype === 'int32') { return new Int32Array(a); } else if (dtype === 'bool') { var bool = new Uint8Array(a.length); for (var i = 0; i < bool.length; ++i) { if (Math.round(a[i]) !== 0) { bool[i] = 1; } } return bool; } else { throw new Error("Unknown data type " + dtype); } } /** * Returns the current high-resolution time in milliseconds relative to an * arbitrary time in the past. It works across different platforms (node.js, * browsers). * * ```js * console.log(tf.util.now()); * ``` * * @doc {heading: 'Util', namespace: 'util'} */ function now() { return env().platform.now(); } /** * Returns a platform-specific implementation of * [`fetch`](https://developer.mozilla.org/en-US/docs/Web/API/Fetch_API). * * If `fetch` is defined on the global object (`window`, `process`, etc.), * `tf.util.fetch` returns that function. * * If not, `tf.util.fetch` returns a platform-specific solution. * * ```js * const resource = await tf.util.fetch('https://unpkg.com/@tensorflow/tfjs'); * // handle response * ``` * * @doc {heading: 'Util'} */ function fetch$2(path, requestInits) { return env().platform.fetch(path, requestInits); } /** * Encodes the provided string into bytes using the provided encoding scheme. * * @param s The string to encode. 
* @param encoding The encoding scheme. Defaults to utf-8. * * @doc {heading: 'Util'} */ function encodeString(s, encoding) { if (encoding === void 0) { encoding = 'utf-8'; } encoding = encoding || 'utf-8'; return env().platform.encode(s, encoding); } /** * Decodes the provided bytes into a string using the provided encoding scheme. * @param bytes The bytes to decode. * * @param encoding The encoding scheme. Defaults to utf-8. * * @doc {heading: 'Util'} */ function decodeString(bytes, encoding) { if (encoding === void 0) { encoding = 'utf-8'; } encoding = encoding || 'utf-8'; return env().platform.decode(bytes, encoding); } var util = { __proto__: null, createScalarValue: createScalarValue, toTypedArray: toTypedArray, now: now, fetch: fetch$2, encodeString: encodeString, decodeString: decodeString, shuffle: shuffle, shuffleCombo: shuffleCombo, clamp: clamp, nearestLargerEven: nearestLargerEven, sum: sum, randUniform: randUniform, distSquared: distSquared, assert: assert, assertShapesMatch: assertShapesMatch, assertNonNull: assertNonNull, flatten: flatten, sizeFromShape: sizeFromShape, isScalarShape: isScalarShape, arraysEqual: arraysEqual, isInt: isInt, tanh: tanh, sizeToSquarishShape: sizeToSquarishShape, createShuffledIndices: createShuffledIndices, rightPad: rightPad, repeatedTry: repeatedTry, inferFromImplicitShape: inferFromImplicitShape, parseAxisParam: parseAxisParam, squeezeShape: squeezeShape, getTypedArrayFromDType: getTypedArrayFromDType, getArrayFromDType: getArrayFromDType, checkConversionForErrors: checkConversionForErrors, isValidDtype: isValidDtype, hasEncodingLoss: hasEncodingLoss, isTypedArray: isTypedArray, bytesPerElement: bytesPerElement, bytesFromStringArray: bytesFromStringArray, isString: isString, isBoolean: isBoolean, isNumber: isNumber, inferDtype: inferDtype, isFunction: isFunction, nearestDivisor: nearestDivisor, computeStrides: computeStrides, toNestedArray: toNestedArray, makeOnesTypedArray: makeOnesTypedArray, makeZerosTypedArray: 
makeZerosTypedArray, makeZerosNestedTypedArray: makeZerosNestedTypedArray, assertNonNegativeIntegerDimensions: assertNonNegativeIntegerDimensions, locToIndex: locToIndex, indexToLoc: indexToLoc, isPromise: isPromise, hexToLong: hexToLong, fingerPrint64: fingerPrint64 }; /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var Profiler = /** @class */ (function () { function Profiler(backendTimer, logger) { this.backendTimer = backendTimer; this.logger = logger; if (logger == null) { this.logger = new Logger(); } } Profiler.prototype.profileKernel = function (kernelName, inputs, f) { var outputs; var holdResultWrapperFn = function () { outputs = f(); }; var timer; var start = now(); if (this.backendTimer.timerAvailable()) { timer = this.backendTimer.time(holdResultWrapperFn); } else { holdResultWrapperFn(); for (var _i = 0, outputs_1 = outputs; _i < outputs_1.length; _i++) { var output = outputs_1[_i]; output.dataSync(); } timer = Promise.resolve({ kernelMs: now() - start }); } if (env().getBool('CHECK_COMPUTATION_FOR_ERRORS')) { var _loop_1 = function (i) { var output = outputs[i]; // Dangling promise here because we don't want to propagate up // asynchronicity. 
output.data().then(function (tensorVals) { checkComputationForErrors(tensorVals, output.dtype, kernelName); }); }; for (var i = 0; i < outputs.length; i++) { _loop_1(i); } } var kernelProfile = { kernelName: kernelName, outputs: outputs, inputs: inputs, timeMs: timer.then(function (timing) { return timing.kernelMs; }), extraInfo: timer.then(function (timing) { return timing.getExtraProfileInfo != null ? timing.getExtraProfileInfo() : ''; }) }; return kernelProfile; }; Profiler.prototype.logKernelProfile = function (kernelProfile) { var _this = this; var kernelName = kernelProfile.kernelName, outputs = kernelProfile.outputs, timeMs = kernelProfile.timeMs, inputs = kernelProfile.inputs, extraInfo = kernelProfile.extraInfo; outputs.forEach(function (result) { Promise.all([result.data(), timeMs, extraInfo]).then(function (valueContainer) { _this.logger.logKernelProfile(kernelName, result, valueContainer[0], valueContainer[1], inputs, valueContainer[2]); }); }); }; return Profiler; }()); function checkComputationForErrors(vals, dtype, kernelName) { if (dtype !== 'float32') { // Only floating point computations will generate NaN values return false; } for (var i = 0; i < vals.length; i++) { var num = vals[i]; if (isNaN(num) || !isFinite(num)) { // Throwing custom exception so behavior is testable. console.warn("Found " + num + " in the result of '" + kernelName + "'"); return true; } } return false; } var Logger = /** @class */ (function () { function Logger() { } Logger.prototype.logKernelProfile = function (name, result, vals, timeMs, inputs, extraInfo) { var time = typeof timeMs === 'number' ? 
rightPad(timeMs + "ms", 9) : timeMs['error']; var paddedName = rightPad(name, 25); var rank = result.rank; var size = result.size; var shape = rightPad(result.shape.toString(), 14); var inputShapesDescription = ''; for (var name_1 in inputs) { var input = inputs[name_1]; if (input != null) { // The input might be a non-tensor (e.g HTMLImageElement), in which case // we claim the output shape as input shape. var inputShape = input.shape || result.shape; var inputRank = inputShape.length; inputShapesDescription += name_1 + ": " + inputRank + "D " + (inputRank > 0 ? inputShape : '') + " "; } } console.log("%c" + paddedName + "\t%c" + time + "\t%c" + rank + "D " + shape + "\t%c" + size + "\t%c" + inputShapesDescription + "\t%c" + extraInfo, 'font-weight:bold', 'color:red', 'color:blue', 'color: orange', 'color: green', 'color: steelblue'); }; return Logger; }()); /** * @license * Copyright 2017 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /** * Computes a list of TapeNodes that connect x to y, filtering everything else * out and preserving the order of the original tape elements. * * @param tape The tape elements to filter. * @param xs The input Tensors. * @param y The output Tensor. */ function getFilteredNodesXToY(tape, xs, y) { // Forward pass to compute all the nodes and Tensors that are transitively a // function of x. 
var tensorsFromX = {}; var nodesFromX = {}; for (var i = 0; i < xs.length; i++) { tensorsFromX[xs[i].id] = true; } for (var i = 0; i < tape.length; i++) { var node = tape[i]; var nodeInputs = node.inputs; for (var inputName in nodeInputs) { var input = nodeInputs[inputName]; var anyInputFromX = false; for (var j = 0; j < xs.length; j++) { if (tensorsFromX[input.id]) { node.outputs.forEach(function (output) { return tensorsFromX[output.id] = true; }); anyInputFromX = true; nodesFromX[node.id] = true; break; } } if (anyInputFromX) { break; } } } // Backward pass to find all of the nodes and Tensors that lead to y. var tensorsLeadToY = {}; tensorsLeadToY[y.id] = true; var nodesToY = {}; for (var i = tape.length - 1; i >= 0; i--) { var node = tape[i]; var nodeInputs = node.inputs; // If any of the outputs lead to y, mark all of the inputs as leading to y. for (var j = 0; j < node.outputs.length; j++) { if (tensorsLeadToY[node.outputs[j].id]) { for (var inputName in nodeInputs) { tensorsLeadToY[nodeInputs[inputName].id] = true; nodesToY[node.id] = true; } break; } } } // Return the paths that come from x and lead to y. var filteredTape = []; for (var i = 0; i < tape.length; i++) { var node = tape[i]; if (nodesFromX[node.id] && nodesToY[node.id]) { // Prune the inputs from the node that aren't a function of x. var prunedInputs = {}; for (var inputName in node.inputs) { var nodeInput = node.inputs[inputName]; if (tensorsFromX[nodeInput.id]) { prunedInputs[inputName] = nodeInput; } } // Copy the node and overwrite inputsAndArgs to the pruned version. var prunedNode = Object.assign({}, node); prunedNode.inputs = prunedInputs; prunedNode.outputs = node.outputs; filteredTape.push(prunedNode); } } return filteredTape; } /** * Backpropagate gradients through the filtered TapeNodes. * * @param tensorAccumulatedGradientMap A map of Tensor to its gradient. This map * is mutated by this method. * @param filteredTape The filtered TapeNodes to backprop through. 
*/ function backpropagateGradients(tensorAccumulatedGradientMap, filteredTape, tidy, add) { var _loop_1 = function (i) { var node = filteredTape[i]; var dys = []; node.outputs.forEach(function (o) { var gradTensor = tensorAccumulatedGradientMap[o.id]; if (gradTensor != null) { dys.push(gradTensor); } else { // This particular output is not in the back-propagation subgraph, so it // does not affect the final output, thus we put null for its dy. dys.push(null); } }); if (node.gradient == null) { throw new Error("Cannot compute gradient: gradient function not found " + ("for " + node.kernelName + ".")); } // Backprop dy through this node and accumulate gradients over the inputs. var inputGradients = node.gradient(dys); var _loop_2 = function (inputName) { if (!(inputName in inputGradients)) { throw new Error("Cannot backprop through input " + inputName + ". " + ("Available gradients found: " + Object.keys(inputGradients) + ".")); } // Call the gradient function. var dx = tidy(function () { return inputGradients[inputName](); }); if (dx.dtype !== 'float32') { throw new Error("Error in gradient for op " + node.kernelName + ". The gradient of input " + (inputName + " must have 'float32' dtype, but has '" + dx.dtype + "'")); } var x = node.inputs[inputName]; if (!arraysEqual(dx.shape, x.shape)) { throw new Error("Error in gradient for op " + node.kernelName + ". The gradient of input " + ("'" + inputName + "' has shape '" + dx.shape + "', which does not match ") + ("the shape of the input '" + x.shape + "'")); } if (tensorAccumulatedGradientMap[x.id] == null) { tensorAccumulatedGradientMap[x.id] = dx; } else { var curGradient = tensorAccumulatedGradientMap[x.id]; tensorAccumulatedGradientMap[x.id] = add(curGradient, dx); curGradient.dispose(); } }; for (var inputName in node.inputs) { _loop_2(inputName); } }; // Walk the tape backward and keep a map of Tensor to its gradient. 
for (var i = filteredTape.length - 1; i >= 0; i--) { _loop_1(i); } } /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ // Maximum number of values before we decide to show ellipsis. var FORMAT_LIMIT_NUM_VALS = 20; // Number of first and last values to show when displaying a, b,...,y, z. var FORMAT_NUM_FIRST_LAST_VALS = 3; // Number of significant digits to show. var FORMAT_NUM_SIG_DIGITS = 7; function tensorToString(vals, shape, dtype, verbose) { var strides = computeStrides(shape); var padPerCol = computeMaxSizePerColumn(vals, shape, dtype, strides); var rank = shape.length; var valsLines = subTensorToString(vals, shape, dtype, strides, padPerCol); var lines = ['Tensor']; if (verbose) { lines.push(" dtype: " + dtype); lines.push(" rank: " + rank); lines.push(" shape: [" + shape + "]"); lines.push(" values:"); } lines.push(valsLines.map(function (l) { return ' ' + l; }).join('\n')); return lines.join('\n'); } function computeMaxSizePerColumn(vals, shape, dtype, strides) { var n = sizeFromShape(shape); var numCols = strides[strides.length - 1]; var padPerCol = new Array(numCols).fill(0); var rank = shape.length; var valuesOrTuples = dtype === 'complex64' ? 
createComplexTuples(vals) : vals; if (rank > 1) { for (var row = 0; row < n / numCols; row++) { var offset = row * numCols; for (var j = 0; j < numCols; j++) { padPerCol[j] = Math.max(padPerCol[j], valToString(valuesOrTuples[offset + j], 0, dtype).length); } } } return padPerCol; } function valToString(val, pad, dtype) { var valStr; if (Array.isArray(val)) { valStr = parseFloat(val[0].toFixed(FORMAT_NUM_SIG_DIGITS)) + " + " + (parseFloat(val[1].toFixed(FORMAT_NUM_SIG_DIGITS)) + "j"); } else if (isString(val)) { valStr = "'" + val + "'"; } else if (dtype === 'bool') { valStr = boolNumToString(val); } else { valStr = parseFloat(val.toFixed(FORMAT_NUM_SIG_DIGITS)).toString(); } return rightPad(valStr, pad); } function boolNumToString(v) { return v === 0 ? 'false' : 'true'; } function subTensorToString(vals, shape, dtype, strides, padPerCol, isLast) { if (isLast === void 0) { isLast = true; } var storagePerElement = dtype === 'complex64' ? 2 : 1; var size = shape[0]; var rank = shape.length; if (rank === 0) { if (dtype === 'complex64') { var complexTuple = createComplexTuples(vals); return [valToString(complexTuple[0], 0, dtype)]; } if (dtype === 'bool') { return [boolNumToString(vals[0])]; } return [vals[0].toString()]; } if (rank === 1) { if (size > FORMAT_LIMIT_NUM_VALS) { var firstValsSize = FORMAT_NUM_FIRST_LAST_VALS * storagePerElement; var firstVals = Array.from(vals.slice(0, firstValsSize)); var lastVals = Array.from(vals.slice((size - FORMAT_NUM_FIRST_LAST_VALS) * storagePerElement, size * storagePerElement)); if (dtype === 'complex64') { firstVals = createComplexTuples(firstVals); lastVals = createComplexTuples(lastVals); } return [ '[' + firstVals.map(function (x, i) { return valToString(x, padPerCol[i], dtype); }) .join(', ') + ', ..., ' + lastVals .map(function (x, i) { return valToString(x, padPerCol[size - FORMAT_NUM_FIRST_LAST_VALS + i], dtype); }) .join(', ') + ']' ]; } var displayVals = dtype === 'complex64' ? 
createComplexTuples(vals) : Array.from(vals); return [ '[' + displayVals.map(function (x, i) { return valToString(x, padPerCol[i], dtype); }) .join(', ') + ']' ]; } // The array is rank 2 or more. var subshape = shape.slice(1); var substrides = strides.slice(1); var stride = strides[0] * storagePerElement; var lines = []; if (size > FORMAT_LIMIT_NUM_VALS) { for (var i = 0; i < FORMAT_NUM_FIRST_LAST_VALS; i++) { var start = i * stride; var end = start + stride; lines.push.apply(lines, subTensorToString(vals.slice(start, end), subshape, dtype, substrides, padPerCol, false /* isLast */)); } lines.push('...'); for (var i = size - FORMAT_NUM_FIRST_LAST_VALS; i < size; i++) { var start = i * stride; var end = start + stride; lines.push.apply(lines, subTensorToString(vals.slice(start, end), subshape, dtype, substrides, padPerCol, i === size - 1 /* isLast */)); } } else { for (var i = 0; i < size; i++) { var start = i * stride; var end = start + stride; lines.push.apply(lines, subTensorToString(vals.slice(start, end), subshape, dtype, substrides, padPerCol, i === size - 1 /* isLast */)); } } var sep = rank === 2 ? ',' : ''; lines[0] = '[' + lines[0] + sep; for (var i = 1; i < lines.length - 1; i++) { lines[i] = ' ' + lines[i] + sep; } var newLineSep = ',\n'; for (var i = 2; i < rank; i++) { newLineSep += '\n'; } lines[lines.length - 1] = ' ' + lines[lines.length - 1] + ']' + (isLast ? '' : newLineSep); return lines; } function createComplexTuples(vals) { var complexTuples = []; for (var i = 0; i < vals.length; i += 2) { complexTuples.push([vals[i], vals[i + 1]]); } return complexTuples; } /** * @license * Copyright 2017 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /** * A mutable object, similar to `tf.Tensor`, that allows users to set values * at locations before converting to an immutable `tf.Tensor`. * * See `tf.buffer` for creating a tensor buffer. * * @doc {heading: 'Tensors', subheading: 'Classes'} */ var TensorBuffer = /** @class */ (function () { function TensorBuffer(shape, dtype, values) { var _this = this; this.dtype = dtype; this.shape = shape.slice(); this.size = sizeFromShape(shape); if (values != null) { var n_1 = values.length; assert(n_1 === this.size, function () { return "Length of values '" + n_1 + "' does not match the size " + ("inferred by the shape '" + _this.size + "'."); }); } if (dtype === 'complex64') { throw new Error("complex64 dtype TensorBuffers are not supported. Please create " + "a TensorBuffer for the real and imaginary parts separately and " + "call tf.complex(real, imag)."); } this.values = values || getArrayFromDType(dtype, this.size); this.strides = computeStrides(shape); } /** * Sets a value in the buffer at a given location. * * @param value The value to set. * @param locs The location indices. 
* * @doc {heading: 'Tensors', subheading: 'Creation'} */ TensorBuffer.prototype.set = function (value) { var _this = this; var locs = []; for (var _i = 1; _i < arguments.length; _i++) { locs[_i - 1] = arguments[_i]; } if (locs.length === 0) { locs = [0]; } assert(locs.length === this.rank, function () { return "The number of provided coordinates (" + locs.length + ") must " + ("match the rank (" + _this.rank + ")"); }); var index = this.locToIndex(locs); this.values[index] = value; }; /** * Returns the value in the buffer at the provided location. * * @param locs The location indices. * * @doc {heading: 'Tensors', subheading: 'Creation'} */ TensorBuffer.prototype.get = function () { var locs = []; for (var _i = 0; _i < arguments.length; _i++) { locs[_i] = arguments[_i]; } if (locs.length === 0) { locs = [0]; } var i = 0; for (var _a = 0, locs_1 = locs; _a < locs_1.length; _a++) { var loc = locs_1[_a]; if (loc < 0 || loc >= this.shape[i]) { var msg = "Requested out of range element at " + locs + ". 
" + (" Buffer shape=" + this.shape); throw new Error(msg); } i++; } var index = locs[locs.length - 1]; for (var i_1 = 0; i_1 < locs.length - 1; ++i_1) { index += this.strides[i_1] * locs[i_1]; } return this.values[index]; }; TensorBuffer.prototype.locToIndex = function (locs) { if (this.rank === 0) { return 0; } else if (this.rank === 1) { return locs[0]; } var index = locs[locs.length - 1]; for (var i = 0; i < locs.length - 1; ++i) { index += this.strides[i] * locs[i]; } return index; }; TensorBuffer.prototype.indexToLoc = function (index) { if (this.rank === 0) { return []; } else if (this.rank === 1) { return [index]; } var locs = new Array(this.shape.length); for (var i = 0; i < locs.length - 1; ++i) { locs[i] = Math.floor(index / this.strides[i]); index -= locs[i] * this.strides[i]; } locs[locs.length - 1] = index; return locs; }; Object.defineProperty(TensorBuffer.prototype, "rank", { get: function () { return this.shape.length; }, enumerable: true, configurable: true }); /** * Creates an immutable `tf.Tensor` object from the buffer. * * @doc {heading: 'Tensors', subheading: 'Creation'} */ TensorBuffer.prototype.toTensor = function () { return trackerFn().makeTensor(this.values, this.shape, this.dtype); }; return TensorBuffer; }()); // For tracking tensor creation and disposal. var trackerFn = null; // Used by chaining methods to call into ops. var opHandler = null; /** * An external consumer can register itself as the tensor tracker. This way * the Tensor class can notify the tracker for every tensor created and * disposed. */ function setTensorTracker(fn) { trackerFn = fn; } /** * An external consumer can register itself as the op handler. This way the * Tensor class can have chaining methods that call into ops via the op * handler. */ function setOpHandler(handler) { opHandler = handler; } /** * A `tf.Tensor` object represents an immutable, multidimensional array of * numbers that has a shape and a data type. 
*
 * For performance reasons, functions that create tensors do not necessarily
 * perform a copy of the data passed to them (e.g. if the data is passed as a
 * `Float32Array`), and changes to the data will change the tensor. This is not
 * a feature and is not supported. To avoid this behavior, use the tensor before
 * changing the input data or create a copy with `copy = tf.add(yourTensor, 0)`.
 *
 * See `tf.tensor` for details on how to create a `tf.Tensor`.
 *
 * @doc {heading: 'Tensors', subheading: 'Classes'}
 */
var Tensor = /** @class */ (function () {
    /**
     * @param shape Tensor shape; copied with slice(), so the caller's array is
     *     not retained.
     * @param dtype Element dtype; falls back to 'float32' when falsy.
     * @param dataId Handle passed to the tracker's read/readSync/disposeTensor
     *     calls to reach the tensor's storage.
     * @param id Identifier stored on the tensor.
     */
    function Tensor(shape, dtype, dataId, id) {
        /** Whether this tensor has been globally kept. */
        this.kept = false;
        this.isDisposedInternal = false;
        this.shape = shape.slice();
        this.dtype = dtype || 'float32';
        this.size = sizeFromShape(shape);
        this.strides = computeStrides(shape);
        this.dataId = dataId;
        this.id = id;
        // Ranks 0-4 are labelled '0'..'4'; anything larger becomes 'higher'.
        this.rankType = (this.rank < 5 ? this.rank.toString() : 'higher');
    }
    // `rank` is derived from the shape on every access.
    Object.defineProperty(Tensor.prototype, "rank", {
        get: function () {
            return this.shape.length;
        },
        enumerable: true,
        configurable: true
    });
    /**
     * Returns a promise of `tf.TensorBuffer` that holds the underlying data.
     *
     * @doc {heading: 'Tensors', subheading: 'Classes'}
     */
    Tensor.prototype.buffer = function () {
        // Transpiled async function: awaits this.data() (label 0 -> 1), then
        // builds the buffer through the registered op handler.
        return __awaiter(this, void 0, void 0, function () {
            var vals;
            return __generator(this, function (_a) {
                switch (_a.label) {
                    case 0: return [4 /*yield*/, this.data()];
                    case 1:
                        vals = _a.sent();
                        return [2 /*return*/, opHandler.buffer(this.shape, this.dtype, vals)];
                }
            });
        });
    };
    /**
     * Returns a `tf.TensorBuffer` that holds the underlying data.
     * @doc {heading: 'Tensors', subheading: 'Classes'}
     */
    Tensor.prototype.bufferSync = function () {
        return opHandler.buffer(this.shape, this.dtype, this.dataSync());
    };
    /**
     * Returns the tensor data as a nested array. The transfer of data is done
     * asynchronously.
     *
     * @doc {heading: 'Tensors', subheading: 'Classes'}
     */
    Tensor.prototype.array = function () {
        // Transpiled async function: awaits this.data(), then nests the flat
        // values by shape (complex64 values are treated as interleaved pairs).
        return __awaiter(this, void 0, void 0, function () {
            var vals;
            return __generator(this, function (_a) {
                switch (_a.label) {
                    case 0: return [4 /*yield*/, this.data()];
                    case 1:
                        vals = _a.sent();
                        return [2 /*return*/, toNestedArray(this.shape, vals, this.dtype === 'complex64')];
                }
            });
        });
    };
    /**
     * Returns the tensor data as a nested array. The transfer of data is done
     * synchronously.
     *
     * @doc {heading: 'Tensors', subheading: 'Classes'}
     */
    Tensor.prototype.arraySync = function () {
        return toNestedArray(this.shape, this.dataSync(), this.dtype === 'complex64');
    };
    /**
     * Asynchronously downloads the values from the `tf.Tensor`. Returns a
     * promise of `TypedArray` that resolves when the computation has finished.
     *
     * @doc {heading: 'Tensors', subheading: 'Classes'}
     */
    Tensor.prototype.data = function () {
        // Transpiled async function. Label 0 throws if disposed and starts the
        // read. For any dtype other than 'string' it jumps to label 2 and
        // returns the read result. For 'string' it awaits the raw bytes
        // (label 1) and decodes each element with decodeString.
        return __awaiter(this, void 0, void 0, function () {
            var data, bytes;
            return __generator(this, function (_a) {
                switch (_a.label) {
                    case 0:
                        this.throwIfDisposed();
                        data = trackerFn().read(this.dataId);
                        if (!(this.dtype === 'string')) return [3 /*break*/, 2];
                        return [4 /*yield*/, data];
                    case 1:
                        bytes = _a.sent();
                        try {
                            return [2 /*return*/, bytes.map(function (b) { return decodeString(b); })];
                        }
                        catch (_b) {
                            throw new Error('Failed to decode the string bytes into utf-8. ' +
                                'To get the original bytes, call tensor.bytes().');
                        }
                        _a.label = 2;
                    case 2: return [2 /*return*/, data];
                }
            });
        });
    };
    /**
     * Synchronously downloads the values from the `tf.Tensor`. This blocks the
     * UI thread until the values are ready, which can cause performance issues.
     *
     * @doc {heading: 'Tensors', subheading: 'Classes'}
     */
    Tensor.prototype.dataSync = function () {
        this.throwIfDisposed();
        var data = trackerFn().readSync(this.dataId);
        // String tensors are stored as encoded bytes; decode each element.
        if (this.dtype === 'string') {
            try {
                return data.map(function (b) { return decodeString(b); });
            }
            catch (_a) {
                throw new Error('Failed to decode the string bytes into utf-8. ' +
                    'To get the original bytes, call tensor.bytes().');
            }
        }
        return data;
    };
    /**
     * Returns the underlying bytes of the tensor's data. For 'string' tensors
     * this is the raw (undecoded) per-element data; for other dtypes it is a
     * Uint8Array over the data's ArrayBuffer.
     */
    Tensor.prototype.bytes = function () {
        return __awaiter(this, void 0, void 0, function () {
            var data;
            return __generator(this, function (_a) {
                switch (_a.label) {
                    case 0:
                        this.throwIfDisposed();
                        return [4 /*yield*/, trackerFn().read(this.dataId)];
                    case 1:
                        data = _a.sent();
                        if (this.dtype === 'string') {
                            return [2 /*return*/, data];
                        }
                        else {
                            // NOTE(review): this views the *entire* underlying
                            // ArrayBuffer. If a backend ever returns a typed array
                            // that is a sub-view (non-zero byteOffset or shorter
                            // length), bytes outside the tensor would be exposed —
                            // confirm backends always return exactly-sized arrays.
                            return [2 /*return*/, new Uint8Array(data.buffer)];
                        }
                }
            });
        });
    };
    /**
     * Disposes `tf.Tensor` from memory. Calling it again on an already
     * disposed tensor is a no-op.
     *
     * @doc {heading: 'Tensors', subheading: 'Classes'}
     */
    Tensor.prototype.dispose = function () {
        if (this.isDisposed) {
            return;
        }
        trackerFn().disposeTensor(this);
        this.isDisposedInternal = true;
    };
    Object.defineProperty(Tensor.prototype, "isDisposed", {
        get: function () {
            return this.isDisposedInternal;
        },
        enumerable: true,
        configurable: true
    });
    /** Throws `Error("Tensor is disposed.")` if this tensor was disposed. */
    Tensor.prototype.throwIfDisposed = function () {
        if (this.isDisposed) {
            throw new Error("Tensor is disposed.");
        }
    };
    /**
     * Prints the `tf.Tensor`. See `tf.print` for details.
     *
     * @param verbose Whether to print verbose information about the tensor,
     * including dtype and size.
     *
     * @doc {heading: 'Tensors', subheading: 'Classes'}
     */
    Tensor.prototype.print = function (verbose) {
        if (verbose === void 0) { verbose = false; }
        return opHandler.print(this, verbose);
    };
    /**
     * Returns a copy of the tensor. See `tf.clone` for details.
     * @doc {heading: 'Tensors', subheading: 'Classes'}
     */
    Tensor.prototype.clone = function () {
        this.throwIfDisposed();
        return opHandler.clone(this);
    };
    /**
     * Returns a human-readable description of the tensor. Useful for logging.
     * Reads values synchronously via dataSync(), so it throws on a disposed
     * tensor.
     *
     * @doc {heading: 'Tensors', subheading: 'Classes'}
     */
    Tensor.prototype.toString = function (verbose) {
        if (verbose === void 0) { verbose = false; }
        var vals = this.dataSync();
        return tensorToString(vals, this.shape, this.dtype, verbose);
    };
    /** Returns this tensor cast to `dtype` via the registered op handler. */
    Tensor.prototype.cast = function (dtype) {
        this.throwIfDisposed();
        return opHandler.cast(this, dtype);
    };
    /** Wraps this tensor in a Variable via the tracker (trainable by default). */
    Tensor.prototype.variable = function (trainable, name, dtype) {
        if (trainable === void 0) { trainable = true; }
        this.throwIfDisposed();
        return trackerFn().makeVariable(this, trainable, name, dtype);
    };
    return Tensor;
}());
// Duck-typed `instanceof`: any object that exposes data, dataSync and
// throwIfDisposed is considered a Tensor.
Object.defineProperty(Tensor, Symbol.hasInstance, {
    value: function (instance) {
        // Implementation note: we should use properties of the object that will be
        // defined before the constructor body has finished executing (methods).
        // This is because when this code is transpiled by babel, babel will call
        // classCallCheck before the constructor body is run.
        // See https://github.com/tensorflow/tfjs/issues/3384 for backstory.
        return !!instance && instance.data != null && instance.dataSync != null &&
            instance.throwIfDisposed != null;
    }
});
function getGlobalTensorClass() {
    // Use getGlobal so that we can augment the Tensor class across package
    // boundaries because the node resolution alg may result in different modules
    // being returned for this file depending on the path they are loaded from.
    return getGlobal('Tensor', function () {
        return Tensor;
    });
}
// Global side effect. Cache global reference to Tensor class
getGlobalTensorClass();
/**
 * A mutable `tf.Tensor`, useful for persisting state, e.g. for training.
* * @doc {heading: 'Tensors', subheading: 'Classes'} */ var Variable = /** @class */ (function (_super) { __extends(Variable, _super); function Variable(initialValue, trainable, name, tensorId) { var _this = _super.call(this, initialValue.shape, initialValue.dtype, initialValue.dataId, tensorId) || this; _this.trainable = trainable; _this.name = name; return _this; } /** * Assign a new `tf.Tensor` to this variable. The new `tf.Tensor` must have * the same shape and dtype as the old `tf.Tensor`. * * @param newValue New tensor to be assigned to this variable. * * @doc {heading: 'Tensors', subheading: 'Classes'} */ Variable.prototype.assign = function (newValue) { if (newValue.dtype !== this.dtype) { throw new Error("dtype of the new value (" + newValue.dtype + ") and " + ("previous value (" + this.dtype + ") must match")); } if (!arraysEqual(newValue.shape, this.shape)) { throw new Error("shape of the new value (" + newValue.shape + ") and " + ("previous value (" + this.shape + ") must match")); } trackerFn().disposeTensor(this); this.dataId = newValue.dataId; trackerFn().incRef(this, null /* backend */); }; Variable.prototype.dispose = function () { trackerFn().disposeVariable(this); this.isDisposedInternal = true; }; return Variable; }(Tensor)); Object.defineProperty(Variable, Symbol.hasInstance, { value: function (instance) { return instance instanceof Tensor && instance.assign != null && instance.assign instanceof Function; } }); /** * @license * Copyright 2017 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
* See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ (function (Rank) { Rank["R0"] = "R0"; Rank["R1"] = "R1"; Rank["R2"] = "R2"; Rank["R3"] = "R3"; Rank["R4"] = "R4"; Rank["R5"] = "R5"; Rank["R6"] = "R6"; })(exports.Rank || (exports.Rank = {})); // Looks for upcasting types. Used, for example, in operations with mixed dtype // inputs. var UpcastInt32AndMap; (function (UpcastInt32AndMap) { UpcastInt32AndMap["float32"] = "float32"; UpcastInt32AndMap["int32"] = "int32"; UpcastInt32AndMap["bool"] = "int32"; UpcastInt32AndMap["complex64"] = "complex64"; })(UpcastInt32AndMap || (UpcastInt32AndMap = {})); var UpcastBoolAndMap; (function (UpcastBoolAndMap) { UpcastBoolAndMap["float32"] = "float32"; UpcastBoolAndMap["int32"] = "int32"; UpcastBoolAndMap["bool"] = "bool"; UpcastBoolAndMap["complex64"] = "complex64"; })(UpcastBoolAndMap || (UpcastBoolAndMap = {})); var UpcastFloat32AndMap; (function (UpcastFloat32AndMap) { UpcastFloat32AndMap["float32"] = "float32"; UpcastFloat32AndMap["int32"] = "float32"; UpcastFloat32AndMap["bool"] = "float32"; UpcastFloat32AndMap["complex64"] = "complex64"; })(UpcastFloat32AndMap || (UpcastFloat32AndMap = {})); var UpcastComplex64AndMap; (function (UpcastComplex64AndMap) { UpcastComplex64AndMap["float32"] = "complex64"; UpcastComplex64AndMap["int32"] = "complex64"; UpcastComplex64AndMap["bool"] = "complex64"; UpcastComplex64AndMap["complex64"] = "complex64"; })(UpcastComplex64AndMap || (UpcastComplex64AndMap = {})); var upcastTypeMap = { 'float32': UpcastFloat32AndMap, 'int32': UpcastInt32AndMap, 'bool': UpcastBoolAndMap, 'complex64': UpcastComplex64AndMap }; function upcastType(typeA, typeB) { if (typeA === 'string' || typeB === 'string') { if (typeA === 'string' && typeB === 'string') { return 'string'; } throw new Error("Can not upcast " + typeA + " with " + typeB); } return upcastTypeMap[typeA][typeB]; } 
/** Returns the output type after summation. */ function sumOutType(type) { return upcastType(type, 'int32'); } /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function makeTypesMatch(a, b) { if (a.dtype === b.dtype) { return [a, b]; } var dtype = upcastType(a.dtype, b.dtype); return [a.cast(dtype), b.cast(dtype)]; } function assertTypesMatch(a, b) { assert(a.dtype === b.dtype, function () { return "The dtypes of the first(" + a.dtype + ") and" + (" second(" + b.dtype + ") input must match"); }); } function isTensorInList(tensor, tensorList) { return tensorList.some(function (x) { return x.id === tensor.id; }); } /** * Extracts any `Tensor`s found within the provided object. * * @param container an object that may be a `Tensor` or may directly contain * `Tensor`s, such as a `Tensor[]` or `{key: Tensor, ...}`. In general it * is safe to pass any object here, except that `Promise`s are not * supported. * @returns An array of `Tensors` found within the passed object. If the * argument is simply a `Tensor', a list containing that `Tensor` is * returned. If the object is not a `Tensor` or does not * contain `Tensors`, an empty list is returned. 
*/ function getTensorsInContainer(result) { var list = []; var seen = new Set(); walkTensorContainer(result, list, seen); return list; } function walkTensorContainer(container, list, seen) { if (container == null) { return; } if (container instanceof Tensor) { list.push(container); return; } if (!isIterable(container)) { return; } // Iteration over keys works also for arrays. var iterable = container; for (var k in iterable) { var val = iterable[k]; if (!seen.has(val)) { seen.add(val); walkTensorContainer(val, list, seen); } } } // tslint:disable-next-line:no-any function isIterable(obj) { return Array.isArray(obj) || typeof obj === 'object'; } var tensor_util = { __proto__: null, makeTypesMatch: makeTypesMatch, assertTypesMatch: assertTypesMatch, isTensorInList: isTensorInList, getTensorsInContainer: getTensorsInContainer }; /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function isRegisteredKernelInvocation(kernelInvocation) { return kernelInvocation.kernelName != null; } var EngineState = /** @class */ (function () { function EngineState() { // Public since optimizers will use it. this.registeredVariables = {}; this.nextTapeNodeId = 0; this.numBytes = 0; this.numTensors = 0; this.numStringTensors = 0; this.numDataBuffers = 0; // Number of nested tf.grad() statements when computing higher-order // gradients. E.g. 
`1` for first-order gradients and `2` for second-order
      // gradients. Used to track if the tape should be removed after a backprop.
      this.gradientDepth = 0;
      // Number of nested kernel calls. runKernelFunc() increments this while a
      // kernel executes and isTapeOn() requires it to be 0, so the tape is
      // turned off (depth greater than 0) for ops invoked from inside a kernel.
      this.kernelDepth = 0;
      this.scopeStack = [];
      /**
       * Keeps track of the number of data moves during a kernel execution. We
       * maintain a stack since kernels can call other kernels, recursively.
       */
      this.numDataMovesStack = [];
      this.nextScopeId = 0;
      this.tensorInfo = new WeakMap();
      this.profiling = false;
      this.activeProfile = {
        newBytes: 0,
        newTensors: 0,
        peakBytes: 0,
        kernels: [],
        result: null,
        // Distinct names of the kernels recorded in `kernels`.
        get kernelNames() {
          return Array.from(new Set(this.kernels.map(function (k) { return k.name; })));
        }
      };
    }
    /** Disposes every variable registered in this state. */
    EngineState.prototype.dispose = function () {
      for (var variableName in this.registeredVariables) {
        this.registeredVariables[variableName].dispose();
      }
    };
    return EngineState;
  }());
  /**
   * The engine owns backend registration and initialization, kernel
   * execution, tensor bookkeeping (counts, bytes, scopes) and the gradient
   * tape.
   *
   * The async methods are TypeScript down-leveled: __awaiter/__generator drive
   * a `switch (label)` state machine in which a returned `[4, x]` awaits x
   * (marked "yield"), `[3, n]` jumps to `case n` (marked "break") and `[2, v]`
   * returns v (marked "return").
   */
  var Engine = /** @class */ (function () {
    /**
     * @param ENV The environment object; read via `getBool()` and reset by
     *     `reset()`.
     */
    function Engine(ENV) {
      this.ENV = ENV;
      // Initialized backend instances, keyed by backend name.
      this.registry = {};
      // Registered `{factory, priority}` entries, keyed by backend name.
      this.registryFactory = {};
      // Incremented for every async init; used to discard superseded inits.
      this.pendingBackendInitId = 0;
      this.state = new EngineState();
    }
    /**
     * Resolves once a backend is usable.
     *
     * If an async backend initialization is pending, waits for it. If a
     * backend is already active, returns immediately. Otherwise tries the
     * registered backends in descending priority order and activates the
     * first one whose initialization succeeds. Rejects with an Error if every
     * initialization fails.
     */
    Engine.prototype.ready = function () {
      return __awaiter(this, void 0, void 0, function () {
        var sortedBackends, i, backendName, success;
        return __generator(this, function (_a) {
          switch (_a.label) {
            case 0:
              if (this.pendingBackendInit != null) {
                return [2 /*return*/, this.pendingBackendInit.then(function () { })];
              }
              if (this.backendInstance != null) {
                return [2 /*return*/];
              }
              sortedBackends = this.getSortedBackends();
              i = 0;
              _a.label = 1;
            case 1:
              // Loop head: stop once every backend has been tried.
              if (!(i < sortedBackends.length)) return [3 /*break*/, 5];
              backendName = sortedBackends[i];
              return [4 /*yield*/, this.initializeBackend(backendName).success];
            case 2:
              success = _a.sent();
              if (!success) return [3 /*break*/, 4];
              return [4 /*yield*/, this.setBackend(backendName)];
            case 3:
              _a.sent();
              return [2 /*return*/];
            case 4:
              i++;
              return [3 /*break*/, 1];
            case 5: throw new Error("Could not initialize any backends, all backend initializations " +
              "failed.");
          }
        });
      });
    };
    /**
     * The active backend instance, initialized lazily on first access.
     *
     * Throws in two cases, and callers must await tf.ready() or
     * tf.setBackend() first:
     * - an async backend initialization is still pending;
     * - the first backend (by priority) whose initialization did not fail
     *   synchronously has an asynchronous factory.
     */
    Object.defineProperty(Engine.prototype, "backend", {
      get: function () {
        if (this.pendingBackendInit != null) {
          throw new Error("Backend '" + this.backendName + "' has not yet been initialized. Make " +
            "sure to await tf.ready() or await tf.setBackend() before calling " +
            "other methods");
        }
        if (this.backendInstance == null) {
          var _a = this.initializeBackendsAndReturnBest(), name_1 = _a.name, asyncInit = _a.asyncInit;
          if (asyncInit) {
            throw new Error("The highest priority backend '" + name_1 + "' has not yet been " +
              "initialized. Make sure to await tf.ready() or " +
              "await tf.setBackend() before calling other methods");
          }
          // The backend was just placed in the registry synchronously, so
          // setBackend() takes its no-await path and sets backendInstance
          // before returning. NOTE(review): this relies on the __awaiter /
          // __generator helpers running that path synchronously; confirm.
          this.setBackend(name_1);
        }
        return this.backendInstance;
      },
      enumerable: true,
      configurable: true
    });
    /** Returns the names of all registered backend factories. */
    Engine.prototype.backendNames = function () {
      return Object.keys(this.registryFactory);
    };
    /**
     * Returns the initialized backend registered under `backendName`,
     * initializing it synchronously from its factory if needed.
     *
     * Returns null when no factory is registered, or when the factory is
     * asynchronous (the backend is not ready yet). NOTE(review): if the
     * factory throws synchronously, no registry entry is created and
     * `undefined` is returned instead of null.
     */
    Engine.prototype.findBackend = function (backendName) {
      if (!(backendName in this.registry)) {
        // If the backend hasn't been initialized but we have a registry entry for
        // it, initialize it and return it.
        if (backendName in this.registryFactory) {
          var asyncInit = this.initializeBackend(backendName).asyncInit;
          if (asyncInit) {
            // Backend is not ready yet.
            return null;
          }
        }
        else {
          return null;
        }
      }
      return this.registry[backendName];
    };
    /** Returns the factory registered under `backendName`, or null. */
    Engine.prototype.findBackendFactory = function (backendName) {
      if (!(backendName in this.registryFactory)) {
        return null;
      }
      return this.registryFactory[backendName].factory;
    };
    /**
     * Registers a backend factory under `backendName` with the given priority.
     *
     * The priority defaults to 1; getSortedBackends() tries higher priorities
     * first. The factory is not invoked here. Returns false, leaving the
     * existing factory in place, if the name is already registered; otherwise
     * returns true.
     */
    Engine.prototype.registerBackend = function (backendName, factory, priority) {
      if (priority === void 0) { priority = 1; }
      if (backendName in this.registryFactory) {
        console.warn(backendName + " backend was already registered. " +
          "Reusing existing backend factory.");
        return false;
      }
      this.registryFactory[backendName] = { factory: factory, priority: priority };
      return true;
    };
    /**
     * Makes `backendName` the active backend.
     *
     * If the backend is not in the registry yet, it is initialized first, and
     * the factory is awaited when it is asynchronous. Resolves to true on
     * success and to false if initialization failed. Rejects if no factory is
     * registered under that name. On success it runs the setup functions of
     * the backend's registered kernels and creates a fresh Profiler.
     *
     * NOTE(review): this.backendName is assigned before initialization, so it
     * keeps `backendName` even when false is returned.
     */
    Engine.prototype.setBackend = function (backendName) {
      return __awaiter(this, void 0, void 0, function () {
        var _a, success, asyncInit, result, _b;
        return __generator(this, function (_c) {
          switch (_c.label) {
            case 0:
              if (this.registryFactory[backendName] == null) {
                throw new Error("Backend name '" + backendName + "' not found in registry");
              }
              this.backendName = backendName;
              // Already initialized: skip straight to activation (case 4).
              if (!(this.registry[backendName] == null)) return [3 /*break*/, 4];
              this.backendInstance = null;
              _a = this.initializeBackend(backendName), success = _a.success, asyncInit = _a.asyncInit;
              if (!asyncInit) return [3 /*break*/, 2];
              return [4 /*yield*/, success];
            case 1:
              _b = _c.sent();
              return [3 /*break*/, 3];
            case 2:
              _b = success;
              _c.label = 3;
            case 3:
              result = _b;
              if (!result) {
                return [2 /*return*/, false];
              }
              _c.label = 4;
            case 4:
              this.backendInstance = this.registry[backendName];
              this.setupRegisteredKernels();
              // Reset the profiler.
              this.profiler = new Profiler(this.backendInstance);
              return [2 /*return*/, true];
          }
        });
      });
    };
    /**
     * Calls `setupFunc` (when present) of every kernel registered for the
     * active backend, passing the active backend instance.
     */
    Engine.prototype.setupRegisteredKernels = function () {
      var _this = this;
      var kernels = getKernelsForBackend(this.backendName);
      kernels.forEach(function (kernel) {
        if (kernel.setupFunc != null) {
          kernel.setupFunc(_this.backendInstance);
        }
      });
    };
    /**
     * Calls `disposeFunc` (when present) of every kernel registered for
     * `backendName`, passing that backend's registry instance.
     */
    Engine.prototype.disposeRegisteredKernels = function (backendName) {
      var _this = this;
      var kernels = getKernelsForBackend(backendName);
      kernels.forEach(function (kernel) {
        if (kernel.disposeFunc != null) {
          kernel.disposeFunc(_this.registry[backendName]);
        }
      });
    };
    /**
     * Initializes a backend by looking up the backend name in the factory
     * registry and calling the factory method.
     *
     * Returns `{success, asyncInit}`:
     * - when the factory returns a thenable, `asyncInit` is true and `success`
     *   is a Promise<boolean>, also stored as `pendingBackendInit`;
     * - otherwise `success` is a boolean telling whether the factory call
     *   succeeded.
     *
     * Throws an error if there is no backend in the factory registry.
*/
    Engine.prototype.initializeBackend = function (backendName) {
      var _this = this;
      var registryFactoryEntry = this.registryFactory[backendName];
      if (registryFactoryEntry == null) {
        throw new Error("Cannot initialize backend " + backendName + ", no registration found.");
      }
      try {
        var backend = registryFactoryEntry.factory();
        /* Test if the factory returns a promise.
        Done in a more liberal way than
        previous 'Promise.resolve(backend)===backend'
        as we needed to account for custom Promise
        implementations (e.g. Angular) */
        if (backend && !(backend instanceof KernelBackend) &&
          typeof backend.then === 'function') {
          // Each async init gets a new, increasing id. A result whose id is
          // older than pendingBackendInitId was superseded (by a newer init,
          // removeBackend() or reset()) and is discarded.
          var promiseId_1 = ++this.pendingBackendInitId;
          var success = backend
            .then(function (backendInstance) {
              // Outdated promise. Another backend was set in the meantime.
              if (promiseId_1 < _this.pendingBackendInitId) {
                return false;
              }
              _this.registry[backendName] = backendInstance;
              _this.pendingBackendInit = null;
              return true;
            })
            .catch(function (err) {
              // Outdated promise. Another backend was set in the meantime.
              if (promiseId_1 < _this.pendingBackendInitId) {
                return false;
              }
              _this.pendingBackendInit = null;
              console.warn("Initialization of backend " + backendName + " failed");
              console.warn(err.stack || err.message);
              return false;
            });
          this.pendingBackendInit = success;
          return { success: success, asyncInit: true };
        }
        else {
          this.registry[backendName] = backend;
          return { success: true, asyncInit: false };
        }
      }
      catch (err) {
        // A synchronous factory failure is reported and swallowed; the
        // registry is left without an entry for this backend.
        console.warn("Initialization of backend " + backendName + " failed");
        console.warn(err.stack || err.message);
        return { success: false, asyncInit: false };
      }
    };
    /**
     * Unregisters the factory for `backendName`.
     *
     * If the backend was initialized, its registered kernels and its instance
     * are disposed. If it is the active backend, the engine is left with no
     * active backend. Throws if no factory is registered under that name.
     */
    Engine.prototype.removeBackend = function (backendName) {
      if (!(backendName in this.registryFactory)) {
        throw new Error(backendName + " backend not found in registry");
      }
      if (this.backendName === backendName && this.pendingBackendInit != null) {
        // There is a pending promise of the backend we want to remove. Make it
        // obsolete.
        this.pendingBackendInitId++;
      }
      if (backendName in this.registry) {
        this.disposeRegisteredKernels(backendName);
        this.registry[backendName].dispose();
        delete this.registry[backendName];
      }
      delete this.registryFactory[backendName];
      // Unset the backend if it is active.
      if (this.backendName === backendName) {
        this.pendingBackendInit = null;
        this.backendName = null;
        this.backendInstance = null;
      }
    };
    /**
     * Returns the registered backend names sorted by descending priority.
     * Throws if no backend factory is registered.
     */
    Engine.prototype.getSortedBackends = function () {
      var _this = this;
      if (Object.keys(this.registryFactory).length === 0) {
        throw new Error('No backend found in registry.');
      }
      return Object.keys(this.registryFactory).sort(function (a, b) {
        // Highest priority comes first.
        return _this.registryFactory[b].priority -
          _this.registryFactory[a].priority;
      });
    };
    /**
     * Initializes backends in priority order and returns `{name, asyncInit}`
     * for the first one that either succeeded synchronously or started an
     * asynchronous initialization (which may still fail later).
     *
     * Throws if every initialization failed synchronously.
     * NOTE(review): initializeBackend() does not check the registry, so the
     * factory of a backend that is already initialized may be invoked again.
     */
    Engine.prototype.initializeBackendsAndReturnBest = function () {
      var sortedBackends = this.getSortedBackends();
      for (var i = 0; i < sortedBackends.length; i++) {
        var backendName = sortedBackends[i];
        var _a = this.initializeBackend(backendName), success = _a.success, asyncInit = _a.asyncInit;
        if (asyncInit || success) {
          return { name: backendName, asyncInit: asyncInit };
        }
      }
      throw new Error("Could not initialize any backends, all backend initializations " +
        "failed.");
    };
    /**
     * Moves the data behind `dataId` from its current backend to `backend`.
     *
     * Reads the values synchronously, disposes them on the source backend,
     * then hands the values, shape, dtype and source ref count to
     * `backend.move`.
     */
    Engine.prototype.moveData = function (backend, dataId) {
      var info = this.state.tensorInfo.get(dataId);
      var srcBackend = info.backend;
      var values = this.readSync(dataId);
      var refCount = srcBackend.refCount(dataId);
      // Delete the tensor from the old backend and move it to the new
      // backend.
      srcBackend.disposeData(dataId, true);
      info.backend = backend;
      backend.move(dataId, values, info.shape, info.dtype, refCount);
      if (this.shouldCheckForMemLeaks()) {
        // Track the number of moves during a kernel execution to correctly
        // detect memory leaks.
        this.state.numDataMovesStack[this.state.numDataMovesStack.length - 1]++;
      }
    };
    /**
     * Runs `fn` inside a new scope.
     *
     * Accepts either `tidy(fn)` or `tidy(name, fn)`. When the scope ends
     * (endScope), every tensor tracked in it is disposed, except kept tensors
     * and tensors contained in the value `fn` returns. Returns `fn()`'s
     * result. Returning a Promise is only reported via console.error.
     */
    Engine.prototype.tidy = function (nameOrFn, fn) {
      var _this = this;
      var name = null;
      if (fn == null) {
        // Called with only 1 argument.
        if (typeof nameOrFn !== 'function') {
          throw new Error('Please provide a function to tidy()');
        }
        fn = nameOrFn;
      }
      else {
        // Called with 2 arguments.
        if (typeof nameOrFn !== 'string' && !(nameOrFn instanceof String)) {
          throw new Error('When calling with two arguments, the first argument ' +
            'to tidy() must be a string');
        }
        if (typeof fn !== 'function') {
          throw new Error('When calling with two arguments, the 2nd argument ' +
            'to tidy() must be a function');
        }
        name = nameOrFn;
        // TODO(nsthorat,smilkov): Do operation logging and performance
        // profiling.
      }
      var result;
      return this.scopedRun(function () { return _this.startScope(name); }, function () { return _this.endScope(result); }, function () {
        result = fn();
        if (result instanceof Promise) {
          console.error('Cannot return a Promise inside of tidy.');
        }
        return result;
      });
    };
    /**
     * Calls `start()`, then `f()`, then `end()`, and returns `f()`'s result.
     *
     * `end()` also runs when `f` throws, after which the error is rethrown.
     * NOTE(review): if `end()` itself throws on the success path, the catch
     * block calls `end()` a second time.
     */
    Engine.prototype.scopedRun = function (start, end, f) {
      start();
      try {
        var res = f();
        end();
        return res;
      }
      catch (ex) {
        end();
        throw ex;
      }
    };
    /** Returns the next tensor id from the static counter on Engine. */
    Engine.prototype.nextTensorId = function () {
      return Engine.nextTensorId++;
    };
    /** Returns the next variable id from the static counter on Engine. */
    Engine.prototype.nextVariableId = function () {
      return Engine.nextVariableId++;
    };
    /**
     * This method is called instead of the public-facing tensor.clone() when
     * saving a tensor for backwards pass. It makes sure to add the clone
     * operation to the tape regardless of being called inside a kernel
     * execution.
     *
     * The recorded gradient casts `dy` to float32. Requires an active scope,
     * because it reads `this.state.activeScope.name`.
     */
    Engine.prototype.clone = function (x) {
      var y = ENGINE.runKernel(Identity, { x: x });
      var inputs = { x: x };
      var grad = function (dy) { return ({
        x: function () {
          var dtype = 'float32';
          var gradInputs = { x: dy };
          var attrs = { dtype: dtype };
          return ENGINE.runKernel(Cast, gradInputs,
          // tslint:disable-next-line: no-unnecessary-type-assertion
          attrs);
        }
      }); };
      var saved = [];
      this.addTapeNode(this.state.activeScope.name, inputs, [y], grad, saved, {});
      return y;
    };
    /**
     * Execute a kernel with the given name and return the output tensor.
     *
     * @param kernelName The name of the kernel to execute.
* @param inputs A map of input names to tensors. * @param attrs A map of attribute names to their values. An attribute is a * primitive (non-tensor) input to the kernel. * @param inputsToSave A list of tensors, inputs to save for the backprop * computation. * @param outputsToSave A list of booleans, specifying which output to save * for the backprop computation. These are booleans since the output * tensors are not visible to the user. */ Engine.prototype.runKernel = function (kernelName, inputs, attrs) { var hasKernel = getKernel(kernelName, this.backendName) != null; if (!hasKernel) { throw new Error("Kernel '" + kernelName + "' not registered for backend '" + this.backendName + "'"); } return this.runKernelFunc({ kernelName: kernelName, inputs: inputs, attrs: attrs }); }; Engine.prototype.shouldCheckForMemLeaks = function () { return this.ENV.getBool('IS_TEST'); }; Engine.prototype.checkKernelForMemLeak = function (kernelName, numDataIdsBefore, outInfos) { var numDataIdsAfter = this.backend.numDataIds(); // Count the number of data ids associated with the result of the kernel. var numOutputDataIds = 0; outInfos.forEach(function (info) { // Complex numbers allocate 3 data ids, one for 'real', one for // 'imaginary', and one for the container that holds the former two. numOutputDataIds += (info.dtype === 'complex64' ? 3 : 1); }); // Account for the number of moves during kernel execution. A "data move" // can happen in the middle of a kernel execution, placing a new (key,value) // pair in the data storage. Since data moves have net zero effect (we // always remove the data from the old backend), we have to cancel them out // when detecting memory leaks. 
var numMoves = this.state.numDataMovesStack[this.state.numDataMovesStack.length - 1]; var dataIdsLeaked = numDataIdsAfter - numDataIdsBefore - numOutputDataIds - numMoves; if (dataIdsLeaked > 0) { throw new Error("Backend '" + this.backendName + "' has an internal memory leak " + ("(" + dataIdsLeaked + " data ids) after running '" + kernelName + "'")); } }; /** * Internal helper method to execute a kernel Func * * Use `runKernel` to execute kernels from outside of engine. */ Engine.prototype.runKernelFunc = function (kernelParams) { var _this = this; var outputs; var saved = []; var isTapeOn = this.isTapeOn(); var startingBytecount = this.state.numBytes; var startingNumTensors = this.state.numTensors; if (this.shouldCheckForMemLeaks()) { this.state.numDataMovesStack.push(0); } var kernelFunc; if (this.backendName == null) { // backend has not been initialized yet (backend initialization is lazy // can be deferred until an op/ kernel is run). // The below getter has side effects that will try to initialize the // backend and set properties like this.backendName // tslint:disable-next-line: no-unused-expression this.backend; } var out; var kernelOrScopeName = isRegisteredKernelInvocation(kernelParams) ? kernelParams.kernelName : this.state.activeScope != null ? this.state.activeScope.name : ''; // Create the kernelFunc from either a registered kernel OR passed in // forward/backward functions (used by custom grad). In this context a // kernelFunc wraps a kernel implementation with some bookkeeping. if (isRegisteredKernelInvocation(kernelParams)) { var kernelName_1 = kernelParams.kernelName, inputs_1 = kernelParams.inputs, attrs_1 = kernelParams.attrs; if (this.backendName == null) { // backend has not been initialized yet (backend initialization is lazy // can be deferred until an op/ kernel is run). 
// The below getter has side effects that will try to initialize the // backend and set properties like this.backendName // tslint:disable-next-line: no-unused-expression this.backend; } var kernel_1 = getKernel(kernelName_1, this.backendName); assert(kernel_1 != null, function () { return "Cannot find registered kernel '" + kernelName_1 + "' for backend '" + _this.backendName + "'"; }); kernelFunc = function () { var numDataIdsBefore = _this.backend.numDataIds(); out = kernel_1.kernelFunc({ inputs: inputs_1, attrs: attrs_1, backend: _this.backend }); var outInfos = Array.isArray(out) ? out : [out]; if (_this.shouldCheckForMemLeaks()) { _this.checkKernelForMemLeak(kernelName_1, numDataIdsBefore, outInfos); } var outTensors = outInfos.map(function (outInfo) { // todo (yassogba) remove this option (Tensor) when node backend // methods have been modularized and they all return tensorInfo. // TensorInfos do not have a rank attribute. if (outInfo.rank != null) { return outInfo; } var _a = outInfo, dataId = _a.dataId, shape = _a.shape, dtype = _a.dtype; return _this.makeTensorFromDataId(dataId, shape, dtype); }); // Save any required inputs and outputs. // Do not save unless we are recording to the tape. Otherwise it would // cause a mem leak since there would be no backprop for these tensors // (which would otherwise dispose them). if (isTapeOn) { var tensorsToSave = _this.getTensorsForGradient(kernelName_1, inputs_1, outTensors); saved = _this.saveTensorsForBackwardMode(tensorsToSave); } return outTensors; }; } else { var forwardFunc_1 = kernelParams.forwardFunc; // Running a customGrad op. var saveFunc_1 = function (tensors) { // Do not save unless we are recording to the tape. Otherwise it would // cause a mem leak since we would never run backprop, which disposes // the kept tensors. 
if (!isTapeOn) { return; } saved = tensors.map(function (tensor) { return _this.keep(_this.clone(tensor)); }); }; kernelFunc = function () { var numDataIdsBefore = _this.backend.numDataIds(); out = _this.tidy(function () { return forwardFunc_1(_this.backend, saveFunc_1); }); var outs = (Array.isArray(out) ? out : [out]); if (_this.shouldCheckForMemLeaks()) { // Scope name is used to print a more helpful error message if needed. _this.checkKernelForMemLeak(kernelOrScopeName, numDataIdsBefore, outs); } return outs; }; } // // Run the kernelFunc. Optionally profiling it. // var inputs = kernelParams.inputs, attrs = kernelParams.attrs; var backwardsFunc = isRegisteredKernelInvocation(kernelParams) ? null : kernelParams.backwardsFunc; var kernelProfile; this.scopedRun( // Stop recording to a tape when running a kernel. function () { return _this.state.kernelDepth++; }, function () { return _this.state.kernelDepth--; }, function () { if (!_this.ENV.getBool('DEBUG') && !_this.state.profiling) { outputs = kernelFunc(); } else { kernelProfile = _this.profiler.profileKernel(kernelOrScopeName, inputs, function () { return kernelFunc(); }); if (_this.ENV.getBool('DEBUG')) { _this.profiler.logKernelProfile(kernelProfile); } outputs = kernelProfile.outputs; } }); if (isTapeOn) { this.addTapeNode(kernelOrScopeName, inputs, outputs, backwardsFunc, saved, attrs); } if (this.state.profiling) { this.state.activeProfile.kernels.push({ name: kernelOrScopeName, bytesAdded: this.state.numBytes - startingBytecount, totalBytesSnapshot: this.state.numBytes, tensorsAdded: this.state.numTensors - startingNumTensors, totalTensorsSnapshot: this.state.numTensors, inputShapes: Object.keys(inputs).map(function (key) { return inputs[key] != null ? inputs[key].shape : null; }), outputShapes: outputs.map(function (item) { return item.shape; }), kernelTimeMs: kernelProfile.timeMs, extraInfo: kernelProfile.extraInfo }); } return (Array.isArray(out) ? 
outputs : outputs[0]); }; /** * Saves tensors used in forward mode for use in backward mode. * * @param tensors the list of tensors to save. */ Engine.prototype.saveTensorsForBackwardMode = function (tensors) { var _this = this; var saved = tensors.map(function (tensor) { return _this.keep(_this.clone(tensor)); }); return saved; }; /** * Returns a list of tensors to save for a given gradient calculation. * * @param kernelName name of kernel to look up gradient for. * @param inputs a map of input tensors. * @param outputs an array of output tensors from forward mode of kernel. */ Engine.prototype.getTensorsForGradient = function (kernelName, inputs, outputs) { var gradConfig = getGradient(kernelName); if (gradConfig != null) { var inputsToSave = gradConfig.inputsToSave || []; var outputsToSave_1 = gradConfig.outputsToSave || []; // If saveAllInputs is true, all inputs will be saved. Otherwise, inputs // specified in inputsToSave will be saved. var inputTensorsToSave = void 0; if (gradConfig.saveAllInputs) { assert(Array.isArray(inputs), function () { return 'saveAllInputs is true, expected inputs to be an array.'; }); inputTensorsToSave = Object.keys(inputs).map(function (key) { return inputs[key]; }); } else { inputTensorsToSave = inputsToSave.map(function (inputName) { return inputs[inputName]; }); } var outputTensorsToSave = outputs.filter(function (_, i) { return outputsToSave_1[i]; }); return inputTensorsToSave.concat(outputTensorsToSave); } // We return an empty list rather than throw an error because the kernel we // are looking up may not actually be relevant to backproping through the // overall function // // See 'does not error if irrelevant (pruned) ops are missing grads' test // in gradients_test.ts for an example. return []; }; /** * Internal method used by public APIs for tensor creation. Makes a new * tensor with the provided shape, dtype and values. It always * creates a new data id and writes the values to the underlying backend. 
*/
    Engine.prototype.makeTensor = function (values, shape, dtype, backend) {
      if (values == null) {
        throw new Error('Values passed to engine.makeTensor() are null');
      }
      dtype = dtype || 'float32';
      backend = backend || this.backend;
      var backendVals = values;
      // String values are encoded before being written to the backend.
      if (dtype === 'string' && isString(values[0])) {
        backendVals = values.map(function (d) { return encodeString(d); });
      }
      var dataId = backend.write(backendVals, shape, dtype);
      var t = new Tensor(shape, dtype, dataId, this.nextTensorId());
      this.trackTensor(t, backend);
      // Count bytes for string tensors.
      if (dtype === 'string') {
        var info = this.state.tensorInfo.get(dataId);
        var newBytes = bytesFromStringArray(backendVals);
        this.state.numBytes += newBytes - info.bytes;
        info.bytes = newBytes;
      }
      return t;
    };
    /**
     * Internal method used by backends. Makes a new tensor that wraps an
     * existing data id.
     *
     * It writes no data and does not call the backend's incRef. It only
     * registers the new Tensor through trackTensor(), which updates the
     * engine's tensor and byte counts and adds a tensorInfo entry if the data
     * id is not known yet.
     */
    Engine.prototype.makeTensorFromDataId = function (dataId, shape, dtype, backend) {
      dtype = dtype || 'float32';
      var t = new Tensor(shape, dtype, dataId, this.nextTensorId());
      this.trackTensor(t, backend);
      return t;
    };
    /**
     * Creates and registers a Variable from `initialValue`.
     *
     * The value is cast first if `dtype` differs from its dtype. The name
     * defaults to the next variable id. Throws if a variable with the same
     * name is already registered.
     */
    Engine.prototype.makeVariable = function (initialValue, trainable, name, dtype) {
      if (trainable === void 0) { trainable = true; }
      name = name || this.nextVariableId().toString();
      if (dtype != null && dtype !== initialValue.dtype) {
        initialValue = initialValue.cast(dtype);
      }
      var v = new Variable(initialValue, trainable, name, this.nextTensorId());
      if (this.state.registeredVariables[v.name] != null) {
        throw new Error("Variable with name " + v.name + " was already registered");
      }
      this.state.registeredVariables[v.name] = v;
      this.incRef(v, this.backend);
      return v;
    };
    /**
     * Adds tensor `a` to the engine's counters.
     *
     * Records a tensorInfo entry for its data id if none exists yet (backend
     * defaults to the active one). Non-Variable tensors are also tracked in
     * the current scope.
     */
    Engine.prototype.trackTensor = function (a, backend) {
      this.state.numTensors++;
      if (a.dtype === 'string') {
        this.state.numStringTensors++;
      }
      // Bytes for complex numbers are counted by their components. Bytes for
      // string tensors are counted when writing values.
      var bytes = 0;
      if (a.dtype !== 'complex64' && a.dtype !== 'string') {
        bytes = a.size * bytesPerElement(a.dtype);
      }
      this.state.numBytes += bytes;
      if (!this.state.tensorInfo.has(a.dataId)) {
        this.state.numDataBuffers++;
        this.state.tensorInfo.set(a.dataId, {
          backend: backend || this.backend,
          dtype: a.dtype,
          shape: a.shape,
          bytes: bytes
        });
      }
      if (!(a instanceof Variable)) {
        this.track(a);
      }
    };
    // Track the tensor by dataId and increase the refCount for the dataId in the
    // backend.
    // TODO(pyu10055): This is currently used by makeVariable method, to increase
    // refCount on the backend for the dataId. It can potentially be replaced with
    // Identity op instead of calling backend directly.
    // NOTE(review): the ref count is incremented on the active backend
    // (this.backend), not on the `backend` argument.
    Engine.prototype.incRef = function (a, backend) {
      this.trackTensor(a, backend);
      this.backend.incRef(a.dataId);
    };
    /**
     * Forgets `dataId`, but only if it is tracked as belonging to `backend`.
     */
    Engine.prototype.removeDataId = function (dataId, backend) {
      if (this.state.tensorInfo.has(dataId) &&
        this.state.tensorInfo.get(dataId).backend === backend) {
        this.state.tensorInfo.delete(dataId);
        this.state.numDataBuffers--;
      }
    };
    /**
     * Releases tensor `a`.
     *
     * Decrements the engine's counters and asks the owning backend to dispose
     * the data. The data id is forgotten only if the backend reports that the
     * data was disposed. Does nothing for data ids the engine does not track.
     */
    Engine.prototype.disposeTensor = function (a) {
      if (!this.state.tensorInfo.has(a.dataId)) {
        return;
      }
      var info = this.state.tensorInfo.get(a.dataId);
      this.state.numTensors--;
      if (a.dtype === 'string') {
        this.state.numStringTensors--;
        this.state.numBytes -= info.bytes;
      }
      // Don't count bytes for complex numbers as they are counted by their
      // components.
      if (a.dtype !== 'complex64' && a.dtype !== 'string') {
        var bytes = a.size * bytesPerElement(a.dtype);
        this.state.numBytes -= bytes;
      }
      // Remove the reference to dataId if backend dispose the data successfully
      if (info.backend.disposeData(a.dataId)) {
        this.removeDataId(a.dataId, info.backend);
      }
      // TODO(nsthorat): Construct an error and save the stack trace for
      // debugging when in debug mode. Creating a stack trace is too expensive
      // to do unconditionally.
    };
    /** Disposes and unregisters every registered variable. */
    Engine.prototype.disposeVariables = function () {
      for (var varName in this.state.registeredVariables) {
        var v = this.state.registeredVariables[varName];
        this.disposeVariable(v);
      }
    };
    /** Disposes variable `v` and removes it from the registry. */
    Engine.prototype.disposeVariable = function (v) {
      this.disposeTensor(v);
      if (this.state.registeredVariables[v.name] != null) {
        delete this.state.registeredVariables[v.name];
      }
    };
    /**
     * Returns the active backend's memory() info, overwritten with the
     * engine's tensor, data-buffer and byte counts.
     *
     * The result is flagged unreliable while any string tensor exists.
     */
    Engine.prototype.memory = function () {
      var info = this.backend.memory();
      info.numTensors = this.state.numTensors;
      info.numDataBuffers = this.state.numDataBuffers;
      info.numBytes = this.state.numBytes;
      if (this.state.numStringTensors > 0) {
        info.unreliable = true;
        if (info.reasons == null) {
          info.reasons = [];
        }
        info.reasons.push('Memory usage by string tensors is approximate ' +
          '(2 bytes per character)');
      }
      return info;
    };
    /**
     * Runs `query` with profiling on and resolves to the active profile.
     *
     * The profile holds the result, the per-kernel records, peak/new bytes
     * and new tensors. Each kernel's kernelTimeMs and extraInfo are awaited
     * in sequence.
     * NOTE(review): if `query` rejects, state.profiling is never reset to
     * false. When no kernel ran, peakBytes is Math.max() of an empty list,
     * i.e. -Infinity.
     */
    Engine.prototype.profile = function (query) {
      return __awaiter(this, void 0, void 0, function () {
        var startBytes, startNumTensors, _a, _i, _b, kernel, _c, _d;
        return __generator(this, function (_e) {
          switch (_e.label) {
            case 0:
              this.state.profiling = true;
              startBytes = this.state.numBytes;
              startNumTensors = this.state.numTensors;
              this.state.activeProfile.kernels = [];
              _a = this.state.activeProfile;
              return [4 /*yield*/, query()];
            case 1:
              _a.result = _e.sent();
              this.state.profiling = false;
              this.state.activeProfile.peakBytes = Math.max.apply(Math, this.state.activeProfile.kernels.map(function (d) { return d.totalBytesSnapshot; }));
              this.state.activeProfile.newBytes = this.state.numBytes - startBytes;
              this.state.activeProfile.newTensors =
                this.state.numTensors - startNumTensors;
              _i = 0, _b = this.state.activeProfile.kernels;
              _e.label = 2;
            case 2:
              // Loop over the recorded kernels, awaiting their timing fields.
              if (!(_i < _b.length)) return [3 /*break*/, 6];
              kernel = _b[_i];
              _c = kernel;
              return [4 /*yield*/, kernel.kernelTimeMs];
            case 3:
              _c.kernelTimeMs = _e.sent();
              _d = kernel;
              return [4 /*yield*/, kernel.extraInfo];
            case 4:
              _d.extraInfo = _e.sent();
              _e.label = 5;
            case 5:
              _i++;
              return [3 /*break*/, 2];
            case 6: return [2 /*return*/, this.state.activeProfile];
          }
        });
      });
    };
    /**
     * True while at least one gradient tape is open and no kernel is
     * currently executing.
     */
    Engine.prototype.isTapeOn = function () {
      return this.state.gradientDepth > 0 && this.state.kernelDepth === 0;
    };
    /**
     * Appends a node to the active tape.
     *
     * A gradient function registered for `kernelName` takes precedence over
     * `gradientsFunc`. Missing dys are replaced by zero tensors before the
     * gradient function is called.
     */
    Engine.prototype.addTapeNode = function (kernelName, inputs, outputs, gradientsFunc, saved, attrs) {
      var _this = this;
      var tapeNode = { id: this.state.nextTapeNodeId++, kernelName: kernelName, inputs: inputs, outputs: outputs, saved: saved };
      var gradConfig = getGradient(kernelName);
      if (gradConfig != null) {
        gradientsFunc = gradConfig.gradFunc;
      }
      if (gradientsFunc != null) {
        tapeNode.gradient = function (dys) {
          // TODO(smilkov): To optimize back-prop, pass dys that are not used in
          // the backprop graph to the user as null instead of zeros
          dys = dys.map(function (dy, i) {
            if (dy == null) {
              var output = outputs[i];
              var vals = makeZerosTypedArray(output.size, output.dtype);
              return _this.makeTensor(vals, output.shape, output.dtype);
            }
            return dy;
          });
          // Grad functions of ops with single outputs expect a dy, while ops
          // with multiple outputs expect dys (array of dy).
          return gradientsFunc(dys.length > 1 ? dys : dys[0], saved, attrs);
        };
      }
      this.state.activeTape.push(tapeNode);
    };
    /** Marks `result` as kept so that endScope() will not dispose it. */
    Engine.prototype.keep = function (result) {
      result.kept = true;
      return result;
    };
    /** Opens a gradient tape; the outermost call starts with an empty tape. */
    Engine.prototype.startTape = function () {
      if (this.state.gradientDepth === 0) {
        this.state.activeTape = [];
      }
      this.state.gradientDepth++;
    };
    /** Closes the innermost gradient tape level. */
    Engine.prototype.endTape = function () {
      this.state.gradientDepth--;
    };
    /**
     * Start a scope. Use this with endScope() to achieve the same functionality
     * as tidy() without the need for a function closure.
     */
    Engine.prototype.startScope = function (name) {
      var scopeInfo = {
        track: [],
        name: 'unnamed scope',
        id: this.state.nextScopeId++
      };
      if (name) {
        scopeInfo.name = name;
      }
      this.state.scopeStack.push(scopeInfo);
      this.state.activeScope = scopeInfo;
    };
    /**
     * End a scope. Use this with startScope() to achieve the same functionality
     * as scope() without the need for a function closure.
*/
    Engine.prototype.endScope = function (result) {
      var _this = this;
      // Tensors reachable from `result` survive the scope and move to the parent.
      var tensorsToTrackInParent = getTensorsInContainer(result);
      var tensorsToTrackInParentSet = new Set(tensorsToTrackInParent.map(function (t) { return t.id; }));
      // Dispose the arrays tracked in this scope.
      // Requires an active scope: activeScope.track is read unconditionally.
      for (var i = 0; i < this.state.activeScope.track.length; i++) {
        var tensor = this.state.activeScope.track[i];
        if (!tensor.kept && !tensorsToTrackInParentSet.has(tensor.id)) {
          tensor.dispose();
        }
      }
      var oldScope = this.state.scopeStack.pop();
      this.state.activeScope = this.state.scopeStack.length === 0 ?
        null :
        this.state.scopeStack[this.state.scopeStack.length - 1];
      // Track the current result in the parent scope.
      tensorsToTrackInParent.forEach(function (tensor) {
        // Only track the tensor if was allocated in the inner scope and is not
        // globally kept.
        if (!tensor.kept && tensor.scopeId === oldScope.id) {
          _this.track(tensor);
        }
      });
    };
    /**
     * Returns gradients of `f` with respect to each of the `xs`. The gradients
     * returned are of the same length as `xs`, but some might be null if `f`
     * was not a function of that `x`. It also takes optional dy to multiply the
     * gradient, which defaults to `1`.
     *
     * Resolves to `{value: y, grads}`. Throws if `xs` is empty, if `dy` is not
     * float32, if `f` does not return a Tensor, or (unless allowNoGradients)
     * if no taped op connects any x to y. Saved tensors are disposed and the
     * tape is cleared once no outer gradient computation is active.
     */
    Engine.prototype.gradients = function (f, xs, dy, allowNoGradients) {
      var _this = this;
      if (allowNoGradients === void 0) { allowNoGradients = false; }
      assert(xs.length > 0, function () { return 'gradients() received an empty list of xs.'; });
      if (dy != null && dy.dtype !== 'float32') {
        throw new Error("dy must have 'float32' dtype, but has '" + dy.dtype + "'");
      }
      var y = this.scopedRun(function () { return _this.startTape(); }, function () { return _this.endTape(); }, function () { return _this.tidy('forward', f); });
      assert(y instanceof Tensor, function () { return 'The result y returned by f() must be a tensor.'; });
      // Filter out the nodes that don't connect x => y.
      var filteredTape = getFilteredNodesXToY(this.state.activeTape, xs, y);
      if (!allowNoGradients && filteredTape.length === 0 && xs.length > 0) {
        throw new Error('Cannot compute gradient of y=f(x) with respect to x. Make sure ' +
          'that the f you passed encloses all operations that lead from x ' +
          'to y.');
      }
      return this.tidy('backward', function () {
        var accumulatedGradientMap = {};
        accumulatedGradientMap[y.id] = (dy == null) ? ones(y.shape) : dy;
        // Backprop gradients through the filtered nodes.
        backpropagateGradients(accumulatedGradientMap, filteredTape,
        // Pass the tidy function to avoid circular dep with `tape.ts`.
        function (f) { return _this.tidy(f); },
        // Pass an add function to avoid a circular dep with `tape.ts`.
        add);
        var grads = xs.map(function (x) { return accumulatedGradientMap[x.id]; });
        if (_this.state.gradientDepth === 0) {
          // This means that we are not computing higher-order gradients
          // and can clean up the tape.
          _this.state.activeTape.forEach(function (node) {
            for (var _i = 0, _a = node.saved; _i < _a.length; _i++) {
              var tensor = _a[_i];
              tensor.dispose();
            }
          });
          _this.state.activeTape = null;
        }
        return { value: y, grads: grads };
      });
    };
    /**
     * Wraps `f` into a function of tensors that uses a custom gradient.
     *
     * `f(...inputs, save)` must return `{value, gradFunc}`. `gradFunc(dy,
     * saved)` must return one tensor per input (or a single tensor). The
     * returned wrapper runs `f` as a customGrad kernel via runKernelFunc.
     */
    Engine.prototype.customGrad = function (f) {
      var _this = this;
      assert(isFunction(f), function () { return 'The f passed in customGrad(f) must be a function.'; });
      return function () {
        var inputs = [];
        for (var _i = 0; _i < arguments.length; _i++) {
          inputs[_i] = arguments[_i];
        }
        assert(inputs.every(function (t) { return t instanceof Tensor; }), function () { return 'The args passed in customGrad(f)(x1, x2,...) must all be ' +
          'tensors'; });
        var res;
        // Inputs are keyed by their positional index.
        var inputMap = {};
        inputs.forEach(function (input, i) {
          inputMap[i] = input;
        });
        var forwardFunc = function (_, save) {
          res = f.apply(void 0, inputs.concat([save]));
          assert(res.value instanceof Tensor, function () { return 'The function f passed in customGrad(f) must return an ' +
            'object where `obj.value` is a tensor'; });
          assert(isFunction(res.gradFunc), function () { return 'The function f passed in customGrad(f) must return an ' +
            'object where `obj.gradFunc` is a function.'; });
          return res.value;
        };
        var backwardsFunc = function (dy, saved) {
          var gradRes = res.gradFunc(dy, saved);
          var grads = Array.isArray(gradRes) ? gradRes : [gradRes];
          assert(grads.length === inputs.length, function () { return 'The function f passed in customGrad(f) must return an ' +
            'object where `obj.gradFunc` is a function that returns ' +
            'the same number of tensors as inputs passed to f(...).'; });
          assert(grads.every(function (t) { return t instanceof Tensor; }), function () { return 'The function f passed in customGrad(f) must return an ' +
            'object where `obj.gradFunc` is a function that returns ' +
            'a list of only tensors.'; });
          var gradMap = {};
          grads.forEach(function (grad, i) {
            gradMap[i] = function () { return grad; };
          });
          return gradMap;
        };
        return _this.runKernelFunc({
          forwardFunc: forwardFunc,
          backwardsFunc: backwardsFunc,
          inputs: inputMap,
        });
      };
    };
    /** Synchronously reads the values of `dataId` from its owning backend. */
    Engine.prototype.readSync = function (dataId) {
      // Route the read to the correct backend.
      var info = this.state.tensorInfo.get(dataId);
      return info.backend.readSync(dataId);
    };
    /** Asynchronously reads the values of `dataId` from its owning backend. */
    Engine.prototype.read = function (dataId) {
      // Route the read to the correct backend.
      var info = this.state.tensorInfo.get(dataId);
      return info.backend.read(dataId);
    };
    /**
     * Times `query` on the active backend. Resolves to the backend's timing
     * info with `wallMs` (elapsed wall-clock time via now()) added.
     */
    Engine.prototype.time = function (query) {
      return __awaiter(this, void 0, void 0, function () {
        var start, timingInfo;
        return __generator(this, function (_a) {
          switch (_a.label) {
            case 0:
              start = now();
              return [4 /*yield*/, this.backend.time(query)];
            case 1:
              timingInfo = _a.sent();
              timingInfo.wallMs = now() - start;
              return [2 /*return*/, timingInfo];
          }
        });
      });
    };
    /**
     * Tracks a Tensor in the current scope to be automatically cleaned up
     * when the current scope ends, and returns the value.
     *
     * Outside any scope, the tensor is returned untracked.
     *
     * @param result The Tensor to track in the current scope.
     */
    Engine.prototype.track = function (result) {
      if (this.state.activeScope != null) {
        result.scopeId = this.state.activeScope.id;
        this.state.activeScope.track.push(result);
      }
      return result;
    };
    /** Map of registered variables, keyed by variable name. */
    Object.defineProperty(Engine.prototype, "registeredVariables", {
      get: function () {
        return this.state.registeredVariables;
      },
      enumerable: true,
      configurable: true
    });
    /**
     * Resets the engine state. Removes all backends but does not remove
     * registered backend factories.
     */
    Engine.prototype.reset = function () {
      // Make any pending promise obsolete.
      this.pendingBackendInitId++;
      this.state.dispose();
      this.ENV.reset();
      this.state = new EngineState();
      for (var backendName in this.registry) {
        this.disposeRegisteredKernels(backendName);
        this.registry[backendName].dispose();
        delete this.registry[backendName];
      }
      this.backendName = null;
      this.backendInstance = null;
      this.pendingBackendInit = null;
    };
    // Id counters shared by all Engine instances.
    Engine.nextTensorId = 0;
    Engine.nextVariableId = 0;
    return Engine;
  }());
  /** Creates a float32 tensor of ones with the given shape on the global ENGINE. */
  function ones(shape) {
    var values = makeOnesTypedArray(sizeFromShape(shape), 'float32');
    return ENGINE.makeTensor(values, shape, 'float32');
  }
  /**
   * Returns the engine stored on the global namespace as `_tfengine`,
   * creating it on the first call.
   *
   * Every call also re-points the environment global and the tensor tracker
   * at that engine.
   */
  function getOrMakeEngine() {
    var ns = getGlobalNamespace();
    if (ns._tfengine == null) {
      var environment = new Environment(ns);
      ns._tfengine = new Engine(environment);
    }
    setEnvironmentGlobal(ns._tfengine.ENV);
    // Tell the current tensor interface that the global engine is responsible
    // for tracking.
    setTensorTracker(function () { return ns._tfengine; });
    return ns._tfengine;
  }
  var ENGINE = getOrMakeEngine();
  /**
   * An implementation of the add op for use within engine and tape.
   *
   * This allows us to avoid a circular dependency between add.ts and engine.
   * It is exported to be available in tape tests.
   */
  function add(a, b) {
    // We duplicate Add here to avoid a circular dependency with add.ts.
    var inputs = { a: a, b: b };
    return ENGINE.runKernel(Add, inputs);
  }
  /**
   * @license
   * Copyright 2017 Google LLC. All Rights Reserved.
   * Licensed under the Apache License, Version 2.0 (the "License");
   * you may not use this file except in compliance with the License.
   * You may obtain a copy of the License at
   *
   * http://www.apache.org/licenses/LICENSE-2.0
   *
   * Unless required by applicable law or agreed to in writing, software
   * distributed under the License is distributed on an "AS IS" BASIS,
   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   * See the License for the specific language governing permissions and
   * limitations under the License.
* ============================================================================= */ // tslint:disable-next-line:no-any function _isNavigatorDefined() { return typeof navigator !== 'undefined' && navigator != null; } function isMobile(nav) { if (nav || _isNavigatorDefined()) { if (!nav) { nav = navigator; } if (nav.product === 'ReactNative') { return true; } // tslint:disable-next-line:no-any var a = nav.userAgent || nav.vendor || window.opera; // tslint:disable-next-line:max-line-length return /(android|bb\d+|meego).+mobile|avantgo|bada\/|blackberry|blazer|compal|elaine|fennec|hiptop|iemobile|ip(hone|od)|iris|kindle|lge |maemo|midp|mmp|mobile.+firefox|netfront|opera m(ob|in)i|palm( os)?|phone|p(ixi|re)\/|plucker|pocket|psp|series(4|6)0|symbian|treo|up\.(browser|link)|vodafone|wap|windows ce|xda|xiino/i .test(a) || // tslint:disable-next-line:max-line-length /1207|6310|6590|3gso|4thp|50[1-6]i|770s|802s|a wa|abac|ac(er|oo|s\-)|ai(ko|rn)|al(av|ca|co)|amoi|an(ex|ny|yw)|aptu|ar(ch|go)|as(te|us)|attw|au(di|\-m|r |s )|avan|be(ck|ll|nq)|bi(lb|rd)|bl(ac|az)|br(e|v)w|bumb|bw\-(n|u)|c55\/|capi|ccwa|cdm\-|cell|chtm|cldc|cmd\-|co(mp|nd)|craw|da(it|ll|ng)|dbte|dc\-s|devi|dica|dmob|do(c|p)o|ds(12|\-d)|el(49|ai)|em(l2|ul)|er(ic|k0)|esl8|ez([4-7]0|os|wa|ze)|fetc|fly(\-|_)|g1 u|g560|gene|gf\-5|g\-mo|go(\.w|od)|gr(ad|un)|haie|hcit|hd\-(m|p|t)|hei\-|hi(pt|ta)|hp( i|ip)|hs\-c|ht(c(\-| |_|a|g|p|s|t)|tp)|hu(aw|tc)|i\-(20|go|ma)|i230|iac( |\-|\/)|ibro|idea|ig01|ikom|im1k|inno|ipaq|iris|ja(t|v)a|jbro|jemu|jigs|kddi|keji|kgt( |\/)|klon|kpt |kwc\-|kyo(c|k)|le(no|xi)|lg( g|\/(k|l|u)|50|54|\-[a-w])|libw|lynx|m1\-w|m3ga|m50\/|ma(te|ui|xo)|mc(01|21|ca)|m\-cr|me(rc|ri)|mi(o8|oa|ts)|mmef|mo(01|02|bi|de|do|t(\-| |o|v)|zz)|mt(50|p1|v 
)|mwbp|mywa|n10[0-2]|n20[2-3]|n30(0|2)|n50(0|2|5)|n7(0(0|1)|10)|ne((c|m)\-|on|tf|wf|wg|wt)|nok(6|i)|nzph|o2im|op(ti|wv)|oran|owg1|p800|pan(a|d|t)|pdxg|pg(13|\-([1-8]|c))|phil|pire|pl(ay|uc)|pn\-2|po(ck|rt|se)|prox|psio|pt\-g|qa\-a|qc(07|12|21|32|60|\-[2-7]|i\-)|qtek|r380|r600|raks|rim9|ro(ve|zo)|s55\/|sa(ge|ma|mm|ms|ny|va)|sc(01|h\-|oo|p\-)|sdk\/|se(c(\-|0|1)|47|mc|nd|ri)|sgh\-|shar|sie(\-|m)|sk\-0|sl(45|id)|sm(al|ar|b3|it|t5)|so(ft|ny)|sp(01|h\-|v\-|v )|sy(01|mb)|t2(18|50)|t6(00|10|18)|ta(gt|lk)|tcl\-|tdg\-|tel(i|m)|tim\-|t\-mo|to(pl|sh)|ts(70|m\-|m3|m5)|tx\-9|up(\.b|g1|si)|utst|v400|v750|veri|vi(rg|te)|vk(40|5[0-3]|\-v)|vm40|voda|vulc|vx(52|53|60|61|70|80|81|83|85|98)|w3c(\-| )|webc|whit|wi(g |nc|nw)|wmlb|wonu|x700|yas\-|your|zeto|zte\-/i .test(a.substr(0, 4)); } return false; } function isBrowser() { return (typeof window !== 'undefined' && window.document != null) || //@ts-ignore (typeof WorkerGlobalScope !== 'undefined'); } var device_util = { __proto__: null, isMobile: isMobile, isBrowser: isBrowser }; /** * @license * Copyright 2019 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var ENV = env(); /** * This file contains environment-related flag registrations. */ /** Whether to enable debug mode. */ ENV.registerFlag('DEBUG', function () { return false; }, function (debugValue) { if (debugValue) { console.warn('Debugging mode is ON. 
The output of every math call will ' + 'be downloaded to CPU and checked for NaNs. ' + 'This significantly impacts performance.'); } }); /** Whether we are in a browser (as versus, say, node.js) environment. */ ENV.registerFlag('IS_BROWSER', function () { return isBrowser(); }); /** Whether we are in a browser (as versus, say, node.js) environment. */ ENV.registerFlag('IS_NODE', function () { return (typeof process !== 'undefined') && (typeof process.versions !== 'undefined') && (typeof process.versions.node !== 'undefined'); }); /** Whether this browser is Chrome. */ ENV.registerFlag('IS_CHROME', function () { return typeof navigator !== 'undefined' && navigator != null && navigator.userAgent != null && /Chrome/.test(navigator.userAgent) && /Google Inc/.test(navigator.vendor); }); /** * True when the environment is "production" where we disable safety checks * to gain performance. */ ENV.registerFlag('PROD', function () { return false; }); /** * Whether to do sanity checks when inferring a shape from user-provided * values, used when creating a new tensor. */ ENV.registerFlag('TENSORLIKE_CHECK_SHAPE_CONSISTENCY', function () { return ENV.getBool('DEBUG'); }); /** Whether deprecation warnings are enabled. */ ENV.registerFlag('DEPRECATION_WARNINGS_ENABLED', function () { return true; }); /** True if running unit tests. */ ENV.registerFlag('IS_TEST', function () { return false; }); /** Whether to check computation result for errors. */ ENV.registerFlag('CHECK_COMPUTATION_FOR_ERRORS', function () { return true; }); /** Whether the backend needs to wrap input to imageBitmap. */ ENV.registerFlag('WRAP_TO_IMAGEBITMAP', function () { return false; }); /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ function inferShape(val, dtype) { var firstElem = val; if (isTypedArray(val)) { return dtype === 'string' ? [] : [val.length]; } if (!Array.isArray(val)) { return []; // Scalar. } var shape = []; while (Array.isArray(firstElem) || isTypedArray(firstElem) && dtype !== 'string') { shape.push(firstElem.length); firstElem = firstElem[0]; } if (Array.isArray(val) && env().getBool('TENSORLIKE_CHECK_SHAPE_CONSISTENCY')) { deepAssertShapeConsistency(val, shape, []); } return shape; } function deepAssertShapeConsistency(val, shape, indices) { indices = indices || []; if (!(Array.isArray(val)) && !isTypedArray(val)) { assert(shape.length === 0, function () { return "Element arr[" + indices.join('][') + "] is a primitive, " + ("but should be an array/TypedArray of " + shape[0] + " elements"); }); return; } assert(shape.length > 0, function () { return "Element arr[" + indices.join('][') + "] should be a primitive, " + ("but is an array of " + val.length + " elements"); }); assert(val.length === shape[0], function () { return "Element arr[" + indices.join('][') + "] should have " + shape[0] + " " + ("elements, but has " + val.length + " elements"); }); var subShape = shape.slice(1); for (var i = 0; i < val.length; ++i) { deepAssertShapeConsistency(val[i], subShape, indices.concat(i)); } } function assertDtype(expectedDtype, actualDType, argName, functionName) { if (expectedDtype === 'string_or_numeric') { return; } if (expectedDtype == null) { throw new Error("Expected dtype cannot be null."); } 
if (expectedDtype !== 'numeric' && expectedDtype !== actualDType || expectedDtype === 'numeric' && actualDType === 'string') { throw new Error("Argument '" + argName + "' passed to '" + functionName + "' must " + ("be " + expectedDtype + " tensor, but got " + actualDType + " tensor")); } } function convertToTensor(x, argName, functionName, parseAsDtype) { if (parseAsDtype === void 0) { parseAsDtype = 'numeric'; } if (x instanceof Tensor) { assertDtype(parseAsDtype, x.dtype, argName, functionName); return x; } var inferredDtype = inferDtype(x); // If the user expects a bool/int/float, use that info to update the // inferredDtype when it is not a string. if (inferredDtype !== 'string' && ['bool', 'int32', 'float32'].indexOf(parseAsDtype) >= 0) { inferredDtype = parseAsDtype; } assertDtype(parseAsDtype, inferredDtype, argName, functionName); if ((x == null) || (!isTypedArray(x) && !Array.isArray(x) && typeof x !== 'number' && typeof x !== 'boolean' && typeof x !== 'string')) { var type = x == null ? 'null' : x.constructor.name; throw new Error("Argument '" + argName + "' passed to '" + functionName + "' must be a " + ("Tensor or TensorLike, but got '" + type + "'")); } var inferredShape = inferShape(x, inferredDtype); if (!isTypedArray(x) && !Array.isArray(x)) { x = [x]; } var skipTypedArray = true; var values = inferredDtype !== 'string' ? toTypedArray(x, inferredDtype) : flatten(x, [], skipTypedArray); return ENGINE.makeTensor(values, inferredShape, inferredDtype); } function convertToTensorArray(arg, argName, functionName, parseAsDtype) { if (parseAsDtype === void 0) { parseAsDtype = 'numeric'; } if (!Array.isArray(arg)) { throw new Error("Argument " + argName + " passed to " + functionName + " must be a " + '`Tensor[]` or `TensorLike[]`'); } var tensors = arg; return tensors.map(function (t, i) { return convertToTensor(t, argName + "[" + i + "]", functionName, parseAsDtype); }); } /** * @license * Copyright 2018 Google LLC. All Rights Reserved. 
* Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var OP_SCOPE_SUFFIX = '__op'; /** * Used for wrapping functions that perform math operations on * Tensors. The function will be wrapped in a named scope that cleans all * memory usage after the function is done. */ function op(f) { var keys = Object.keys(f); if (keys.length !== 1) { throw new Error("Please provide an object with a single key " + "(operation name) mapping to a function. Got an object with " + (keys.length + " keys.")); } var opName = keys[0]; var fn = f[opName]; // Strip the underscore from the end of the function name. if (opName.endsWith('_')) { opName = opName.substring(0, opName.length - 1); } // add an __op suffix to distinguish ops from kernels in tf.profile opName = opName + OP_SCOPE_SUFFIX; // tslint:disable-next-line:no-any var f2 = function () { var args = []; for (var _i = 0; _i < arguments.length; _i++) { args[_i] = arguments[_i]; } ENGINE.startScope(opName); try { var result = fn.apply(void 0, args); if (isPromise(result)) { console.error('Cannot return a Promise inside of tidy.'); } ENGINE.endScope(result); return result; } catch (ex) { ENGINE.endScope(null); throw ex; } }; Object.defineProperty(f2, 'name', { value: opName, configurable: true }); // tslint:disable-next-line:no-any return f2; } /** * @license * Copyright 2020 Google LLC. All Rights Reserved. 
* Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /** * Converts two real numbers to a complex number. * * Given a tensor `real` representing the real part of a complex number, and a * tensor `imag` representing the imaginary part of a complex number, this * operation returns complex numbers elementwise of the form [r0, i0, r1, i1], * where r represents the real part and i represents the imag part. * * The input tensors real and imag must have the same shape. * * ```js * const real = tf.tensor1d([2.25, 3.25]); * const imag = tf.tensor1d([4.75, 5.75]); * const complex = tf.complex(real, imag); * * complex.print(); * ``` * * @doc {heading: 'Tensors', subheading: 'Creation'} */ function complex_(real, imag) { var $real = convertToTensor(real, 'real', 'complex'); var $imag = convertToTensor(imag, 'imag', 'complex'); assertShapesMatch($real.shape, $imag.shape, "real and imag shapes, " + $real.shape + " and " + $imag.shape + ", " + "must match in call to tf.complex()."); var inputs = { real: $real, imag: $imag }; return ENGINE.runKernel(Complex, inputs); } var complex = op({ complex_: complex_ }); /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /** This is shared code across all tensor creation methods. */ function makeTensor(values, shape, inferredShape, dtype) { if (dtype == null) { dtype = inferDtype(values); } if (dtype === 'complex64') { throw new Error("Cannot construct a complex64 tensor directly. " + "Please use tf.complex(real, imag)."); } if (!isTypedArray(values) && !Array.isArray(values) && typeof values !== 'number' && typeof values !== 'boolean' && typeof values !== 'string') { throw new Error('values passed to tensor(values) must be a number/boolean/string or ' + 'an array of numbers/booleans/strings, or a TypedArray'); } if (shape != null) { assertNonNegativeIntegerDimensions(shape); var providedSize_1 = sizeFromShape(shape); var inferredSize_1 = sizeFromShape(inferredShape); assert(providedSize_1 === inferredSize_1, function () { return "Based on the provided shape, [" + shape + "], the tensor should have " + (providedSize_1 + " values but has " + inferredSize_1); }); for (var i = 0; i < inferredShape.length; ++i) { var inferred = inferredShape[i]; var flatDimsDontMatch = i === inferredShape.length - 1 ? inferred !== sizeFromShape(shape.slice(i)) : true; assert(inferredShape[i] === shape[i] || !flatDimsDontMatch, function () { return "Error creating a new Tensor. Inferred shape " + ("(" + inferredShape + ") does not match the provided ") + ("shape (" + shape + "). 
"); }); } } if (!isTypedArray(values) && !Array.isArray(values)) { values = [values]; } shape = shape || inferredShape; values = dtype !== 'string' ? toTypedArray(values, dtype) : flatten(values, [], true); return ENGINE.makeTensor(values, shape, dtype); } /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /** * Creates a `tf.Tensor` with the provided values, shape and dtype. * * ```js * // Pass an array of values to create a vector. * tf.tensor([1, 2, 3, 4]).print(); * ``` * * ```js * // Pass a nested array of values to make a matrix or a higher * // dimensional tensor. * tf.tensor([[1, 2], [3, 4]]).print(); * ``` * * ```js * // Pass a flat array and specify a shape yourself. * tf.tensor([1, 2, 3, 4], [2, 2]).print(); * ``` * * @param values The values of the tensor. Can be nested array of numbers, * or a flat array, or a `TypedArray`. If the values are strings, * they will be encoded as utf-8 and kept as `Uint8Array[]`. * @param shape The shape of the tensor. Optional. If not provided, * it is inferred from `values`. * @param dtype The data type. * * @doc {heading: 'Tensors', subheading: 'Creation'} */ function tensor(values, shape, dtype) { var inferredShape = inferShape(values, dtype); return makeTensor(values, shape, inferredShape, dtype); } /** * @license * Copyright 2018 Google LLC. All Rights Reserved. 
* Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /* Type definitions for exporting and importing of models. */ /** * A map from Tensor dtype to number of bytes per element of the Tensor. */ var DTYPE_VALUE_SIZE_MAP = { 'float32': 4, 'float16': 2, 'int32': 4, 'uint16': 2, 'uint8': 1, 'bool': 1, 'complex64': 8 }; /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /** Number of bytes reserved for the length of the string. (32bit integer). */ var NUM_BYTES_STRING_LENGTH = 4; /** * Encode a map from names to weight values as an ArrayBuffer, along with an * `Array` of `WeightsManifestEntry` as specification of the encoded weights. * * This function does not perform sharding. * * This function is the reverse of `decodeWeights`. * * @param tensors A map ("dict") from names to tensors. 
* @param group Group to which the weights belong (optional). * @returns A `Promise` of * - A flat `ArrayBuffer` with all the binary values of the `Tensor`s * concatenated. * - An `Array` of `WeightManifestEntry`s, carrying information including * tensor names, `dtype`s and shapes. * @throws Error: on unsupported tensor `dtype`. */ function encodeWeights(tensors, group) { return __awaiter(this, void 0, void 0, function () { var specs, dataPromises, names, _loop_1, i, tensorValues; var _this = this; return __generator(this, function (_a) { switch (_a.label) { case 0: specs = []; dataPromises = []; names = Array.isArray(tensors) ? tensors.map(function (tensor) { return tensor.name; }) : Object.keys(tensors); _loop_1 = function (i) { var name_1 = names[i]; var t = Array.isArray(tensors) ? tensors[i].tensor : tensors[name_1]; if (t.dtype !== 'float32' && t.dtype !== 'int32' && t.dtype !== 'bool' && t.dtype !== 'string' && t.dtype !== 'complex64') { throw new Error("Unsupported dtype in weight '" + name_1 + "': " + t.dtype); } var spec = { name: name_1, shape: t.shape, dtype: t.dtype }; if (t.dtype === 'string') { var utf8bytes = new Promise(function (resolve) { return __awaiter(_this, void 0, void 0, function () { var vals, totalNumBytes, bytes, offset, i_1, val, bytesOfLength; return __generator(this, function (_a) { switch (_a.label) { case 0: return [4 /*yield*/, t.bytes()]; case 1: vals = _a.sent(); totalNumBytes = vals.reduce(function (p, c) { return p + c.length; }, 0) + NUM_BYTES_STRING_LENGTH * vals.length; bytes = new Uint8Array(totalNumBytes); offset = 0; for (i_1 = 0; i_1 < vals.length; i_1++) { val = vals[i_1]; bytesOfLength = new Uint8Array(new Uint32Array([val.length]).buffer); bytes.set(bytesOfLength, offset); offset += NUM_BYTES_STRING_LENGTH; bytes.set(val, offset); offset += val.length; } resolve(bytes); return [2 /*return*/]; } }); }); }); dataPromises.push(utf8bytes); } else { dataPromises.push(t.data()); } if (group != null) { spec.group = group; } 
specs.push(spec); }; for (i = 0; i < names.length; ++i) { _loop_1(i); } return [4 /*yield*/, Promise.all(dataPromises)]; case 1: tensorValues = _a.sent(); return [2 /*return*/, { data: concatenateTypedArrays(tensorValues), specs: specs }]; } }); }); } /** * Decode flat ArrayBuffer as weights. * * This function does not handle sharding. * * This function is the reverse of `encodeWeights`. * * @param buffer A flat ArrayBuffer carrying the binary values of the tensors * concatenated in the order specified in `specs`. * @param specs Specifications of the names, dtypes and shapes of the tensors * whose value are encoded by `buffer`. * @return A map from tensor name to tensor value, with the names corresponding * to names in `specs`. * @throws Error, if any of the tensors has unsupported dtype. */ function decodeWeights(buffer, specs) { // TODO(adarob, cais): Support quantization. var out = {}; var float16Decode; var offset = 0; for (var _i = 0, specs_1 = specs; _i < specs_1.length; _i++) { var spec = specs_1[_i]; var name_2 = spec.name; var dtype = spec.dtype; var shape = spec.shape; var size = sizeFromShape(shape); var values = void 0; if ('quantization' in spec) { var quantization = spec.quantization; if (quantization.dtype === 'uint8' || quantization.dtype === 'uint16') { if (!('min' in quantization && 'scale' in quantization)) { throw new Error("Weight " + spec.name + " with quantization " + quantization.dtype + " " + "doesn't have corresponding metadata min and scale."); } } else if (quantization.dtype === 'float16') { if (dtype !== 'float32') { throw new Error("Weight " + spec.name + " is quantized with " + quantization.dtype + " " + ("which only supports weights of type float32 not " + dtype + ".")); } } else { throw new Error("Weight " + spec.name + " has unknown " + ("quantization dtype " + quantization.dtype + ". 
") + "Supported quantization dtypes are: " + "'uint8', 'uint16', and 'float16'."); } var quantizationSizeFactor = DTYPE_VALUE_SIZE_MAP[quantization.dtype]; var byteBuffer = buffer.slice(offset, offset + size * quantizationSizeFactor); var quantizedArray = (quantization.dtype === 'uint8') ? new Uint8Array(byteBuffer) : new Uint16Array(byteBuffer); if (dtype === 'float32') { if (quantization.dtype === 'uint8' || quantization.dtype === 'uint16') { values = new Float32Array(quantizedArray.length); for (var i = 0; i < quantizedArray.length; i++) { var v = quantizedArray[i]; values[i] = v * quantization.scale + quantization.min; } } else if (quantization.dtype === 'float16') { if (float16Decode === undefined) { float16Decode = getFloat16Decoder(); } values = float16Decode(quantizedArray); } else { throw new Error("Unsupported quantization type " + quantization.dtype + " " + "for weight type float32."); } } else if (dtype === 'int32') { if (quantization.dtype !== 'uint8' && quantization.dtype !== 'uint16') { throw new Error("Unsupported quantization type " + quantization.dtype + " " + "for weight type int32."); } values = new Int32Array(quantizedArray.length); for (var i = 0; i < quantizedArray.length; i++) { var v = quantizedArray[i]; values[i] = Math.round(v * quantization.scale + quantization.min); } } else { throw new Error("Unsupported dtype in weight '" + name_2 + "': " + dtype); } offset += size * quantizationSizeFactor; } else if (dtype === 'string') { var size_1 = sizeFromShape(spec.shape); values = []; for (var i = 0; i < size_1; i++) { var byteLength = new Uint32Array(buffer.slice(offset, offset + NUM_BYTES_STRING_LENGTH))[0]; offset += NUM_BYTES_STRING_LENGTH; var bytes = new Uint8Array(buffer.slice(offset, offset + byteLength)); values.push(bytes); offset += byteLength; } } else { var dtypeFactor = DTYPE_VALUE_SIZE_MAP[dtype]; var byteBuffer = buffer.slice(offset, offset + size * dtypeFactor); if (dtype === 'float32') { values = new Float32Array(byteBuffer); 
} else if (dtype === 'int32') { values = new Int32Array(byteBuffer); } else if (dtype === 'bool') { values = new Uint8Array(byteBuffer); } else if (dtype === 'complex64') { values = new Float32Array(byteBuffer); var real = new Float32Array(values.length / 2); var image = new Float32Array(values.length / 2); for (var i = 0; i < real.length; i++) { real[i] = values[i * 2]; image[i] = values[i * 2 + 1]; } var realTensor = tensor(real, shape, 'float32'); var imageTensor = tensor(image, shape, 'float32'); out[name_2] = complex(realTensor, imageTensor); realTensor.dispose(); imageTensor.dispose(); } else { throw new Error("Unsupported dtype in weight '" + name_2 + "': " + dtype); } offset += size * dtypeFactor; } if (dtype !== 'complex64') { out[name_2] = tensor(values, shape, dtype); } } return out; } /** * Concatenate TypedArrays into an ArrayBuffer. */ function concatenateTypedArrays(xs) { // TODO(adarob, cais): Support quantization. if (xs === null) { throw new Error("Invalid input value: " + JSON.stringify(xs)); } var totalByteLength = 0; // `normalizedXs` is here for this reason: a `TypedArray`'s `buffer' // can have a different byte length from that of the `TypedArray` itself, // for example, when the `TypedArray` is created from an offset in an // `ArrayBuffer`. `normliazedXs` holds `TypedArray`s whose `buffer`s match // the `TypedArray` in byte length. If an element of `xs` does not show // this property, a new `TypedArray` that satisfy this property will be // constructed and pushed into `normalizedXs`. var normalizedXs = []; xs.forEach(function (x) { totalByteLength += x.byteLength; // tslint:disable:no-any normalizedXs.push(x.byteLength === x.buffer.byteLength ? 
x : new x.constructor(x)); if (!(x instanceof Float32Array || x instanceof Int32Array || x instanceof Uint8Array)) { throw new Error("Unsupported TypedArray subtype: " + x.constructor.name); } // tslint:enable:no-any }); var y = new Uint8Array(totalByteLength); var offset = 0; normalizedXs.forEach(function (x) { y.set(new Uint8Array(x.buffer), offset); offset += x.byteLength; }); return y.buffer; } // Use Buffer on Node.js instead of Blob/atob/btoa var useNodeBuffer = typeof Buffer !== 'undefined' && (typeof Blob === 'undefined' || typeof atob === 'undefined' || typeof btoa === 'undefined'); /** * Calculate the byte length of a JavaScript string. * * Note that a JavaScript string can contain wide characters, therefore the * length of the string is not necessarily equal to the byte length. * * @param str Input string. * @returns Byte length. */ function stringByteLength(str) { if (useNodeBuffer) { return Buffer.byteLength(str); } return new Blob([str]).size; } /** * Encode an ArrayBuffer as a base64 encoded string. * * @param buffer `ArrayBuffer` to be converted. * @returns A string that base64-encodes `buffer`. */ function arrayBufferToBase64String(buffer) { if (useNodeBuffer) { return Buffer.from(buffer).toString('base64'); } var buf = new Uint8Array(buffer); var s = ''; for (var i = 0, l = buf.length; i < l; i++) { s += String.fromCharCode(buf[i]); } return btoa(s); } /** * Decode a base64 string as an ArrayBuffer. * * @param str Base64 string. * @returns Decoded `ArrayBuffer`. */ function base64StringToArrayBuffer(str) { if (useNodeBuffer) { var buf = Buffer.from(str, 'base64'); return buf.buffer.slice(buf.byteOffset, buf.byteOffset + buf.byteLength); } var s = atob(str); var buffer = new Uint8Array(s.length); for (var i = 0; i < s.length; ++i) { buffer.set([s.charCodeAt(i)], i); } return buffer.buffer; } /** * Concatenate a number of ArrayBuffers into one. * * @param buffers A number of array buffers to concatenate. 
* @returns Result of concatenating `buffers` in order. */ function concatenateArrayBuffers(buffers) { if (buffers.length === 1) { return buffers[0]; } var totalByteLength = 0; buffers.forEach(function (buffer) { totalByteLength += buffer.byteLength; }); var temp = new Uint8Array(totalByteLength); var offset = 0; buffers.forEach(function (buffer) { temp.set(new Uint8Array(buffer), offset); offset += buffer.byteLength; }); return temp.buffer; } /** * Get the basename of a path. * * Behaves in a way analogous to Linux's basename command. * * @param path */ function basename(path) { var SEPARATOR = '/'; path = path.trim(); while (path.endsWith(SEPARATOR)) { path = path.slice(0, path.length - 1); } var items = path.split(SEPARATOR); return items[items.length - 1]; } /** * Populate ModelArtifactsInfo fields for a model with JSON topology. * @param modelArtifacts * @returns A ModelArtifactsInfo object. */ function getModelArtifactsInfoForJSON(modelArtifacts) { if (modelArtifacts.modelTopology instanceof ArrayBuffer) { throw new Error('Expected JSON model topology, received ArrayBuffer.'); } return { dateSaved: new Date(), modelTopologyType: 'JSON', modelTopologyBytes: modelArtifacts.modelTopology == null ? 0 : stringByteLength(JSON.stringify(modelArtifacts.modelTopology)), weightSpecsBytes: modelArtifacts.weightSpecs == null ? 0 : stringByteLength(JSON.stringify(modelArtifacts.weightSpecs)), weightDataBytes: modelArtifacts.weightData == null ? 0 : modelArtifacts.weightData.byteLength, }; } /** * Computes mantisa table for casting Float16 to Float32 * See http://www.fox-toolkit.org/ftp/fasthalffloatconversion.pdf * * @returns Uint32Array, 2048 mantissa lookup values. 
*/ function computeFloat16MantisaTable() { var convertMantissa = function (i) { var m = i << 13; var e = 0; while ((m & 0x00800000) === 0) { e -= 0x00800000; m <<= 1; } m &= ~0x00800000; e += 0x38800000; return m | e; }; var mantisaTable = new Uint32Array(2048); mantisaTable[0] = 0; for (var i = 1; i < 1024; i++) { mantisaTable[i] = convertMantissa(i); } for (var i = 1024; i < 2048; i++) { mantisaTable[i] = 0x38000000 + ((i - 1024) << 13); } return mantisaTable; } /** * Computes exponent table for casting Float16 to Float32 * See http://www.fox-toolkit.org/ftp/fasthalffloatconversion.pdf * * @returns Uint32Array, 64 exponent lookup values. */ function computeFloat16ExponentTable() { var exponentTable = new Uint32Array(64); exponentTable[0] = 0; exponentTable[31] = 0x47800000; exponentTable[32] = 0x80000000; exponentTable[63] = 0xc7800000; for (var i = 1; i < 31; i++) { exponentTable[i] = i << 23; } for (var i = 33; i < 63; i++) { exponentTable[i] = 0x80000000 + ((i - 32) << 23); } return exponentTable; } /** * Computes offset table for casting Float16 to Float32 * See http://www.fox-toolkit.org/ftp/fasthalffloatconversion.pdf * * @returns Uint32Array, 6d offset values. */ function computeFloat16OffsetTable() { var offsetTable = new Uint32Array(64); for (var i = 0; i < 64; i++) { offsetTable[i] = 1024; } offsetTable[0] = offsetTable[32] = 0; return offsetTable; } /** * Retrieve a Float16 decoder which will decode a ByteArray of Float16 values * to a Float32Array. * * @returns Function (buffer: Uint16Array) => Float32Array which decodes * the Uint16Array of Float16 bytes to a Float32Array. 
*/ function getFloat16Decoder() { // Algorithm is based off of // http://www.fox-toolkit.org/ftp/fasthalffloatconversion.pdf // Cache lookup tables var mantisaTable = computeFloat16MantisaTable(); var exponentTable = computeFloat16ExponentTable(); var offsetTable = computeFloat16OffsetTable(); return function (quantizedArray) { var buffer = new ArrayBuffer(4 * quantizedArray.length); var bufferUint32View = new Uint32Array(buffer); for (var index = 0; index < quantizedArray.length; index++) { var float16Bits = quantizedArray[index]; var float32Bits = mantisaTable[offsetTable[float16Bits >> 10] + (float16Bits & 0x3ff)] + exponentTable[float16Bits >> 10]; bufferUint32View[index] = float32Bits; } return new Float32Array(buffer); }; } /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var IORouterRegistry = /** @class */ (function () { function IORouterRegistry() { this.saveRouters = []; this.loadRouters = []; } IORouterRegistry.getInstance = function () { if (IORouterRegistry.instance == null) { IORouterRegistry.instance = new IORouterRegistry(); } return IORouterRegistry.instance; }; /** * Register a save-handler router. * * @param saveRouter A function that maps a URL-like string onto an instance * of `IOHandler` with the `save` method defined or `null`. 
*/ IORouterRegistry.registerSaveRouter = function (saveRouter) { IORouterRegistry.getInstance().saveRouters.push(saveRouter); }; /** * Register a load-handler router. * * @param loadRouter A function that maps a URL-like string onto an instance * of `IOHandler` with the `load` method defined or `null`. */ IORouterRegistry.registerLoadRouter = function (loadRouter) { IORouterRegistry.getInstance().loadRouters.push(loadRouter); }; /** * Look up IOHandler for saving, given a URL-like string. * * @param url * @returns If only one match is found, an instance of IOHandler with the * `save` method defined. If no match is found, `null`. * @throws Error, if more than one match is found. */ IORouterRegistry.getSaveHandlers = function (url) { return IORouterRegistry.getHandlers(url, 'save'); }; /** * Look up IOHandler for loading, given a URL-like string. * * @param url * @param loadOptions Optional, custom load options. * @returns All valid handlers for `url`, given the currently registered * handler routers. */ IORouterRegistry.getLoadHandlers = function (url, loadOptions) { return IORouterRegistry.getHandlers(url, 'load', loadOptions); }; IORouterRegistry.getHandlers = function (url, handlerType, loadOptions) { var validHandlers = []; var routers = handlerType === 'load' ? 
IORouterRegistry.getInstance().loadRouters : IORouterRegistry.getInstance().saveRouters; routers.forEach(function (router) { var handler = router(url, loadOptions); if (handler !== null) { validHandlers.push(handler); } }); return validHandlers; }; return IORouterRegistry; }()); var registerSaveRouter = function (loudRouter) { return IORouterRegistry.registerSaveRouter(loudRouter); }; var registerLoadRouter = function (loudRouter) { return IORouterRegistry.registerLoadRouter(loudRouter); }; var getSaveHandlers = function (url) { return IORouterRegistry.getSaveHandlers(url); }; var getLoadHandlers = function (url, loadOptions) { return IORouterRegistry.getLoadHandlers(url, loadOptions); }; /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var DATABASE_NAME = 'tensorflowjs'; var DATABASE_VERSION = 1; // Model data and ModelArtifactsInfo (metadata) are stored in two separate // stores for efficient access of the list of stored models and their metadata. // 1. The object store for model data: topology, weights and weight manifests. var MODEL_STORE_NAME = 'models_store'; // 2. The object store for ModelArtifactsInfo, including meta-information such // as the type of topology (JSON vs binary), byte size of the topology, byte // size of the weights, etc. 
var INFO_STORE_NAME = 'model_info_store'; function getIndexedDBFactory() { if (!env().getBool('IS_BROWSER')) { // TODO(cais): Add more info about what IOHandler subtypes are available. // Maybe point to a doc page on the web and/or automatically determine // the available IOHandlers and print them in the error message. throw new Error('Failed to obtain IndexedDB factory because the current environment' + 'is not a web browser.'); } // tslint:disable-next-line:no-any var theWindow = typeof window === 'undefined' ? self : window; var factory = theWindow.indexedDB || theWindow.mozIndexedDB || theWindow.webkitIndexedDB || theWindow.msIndexedDB || theWindow.shimIndexedDB; if (factory == null) { throw new Error('The current browser does not appear to support IndexedDB.'); } return factory; } function setUpDatabase(openRequest) { var db = openRequest.result; db.createObjectStore(MODEL_STORE_NAME, { keyPath: 'modelPath' }); db.createObjectStore(INFO_STORE_NAME, { keyPath: 'modelPath' }); } /** * IOHandler subclass: Browser IndexedDB. * * See the doc string of `browserIndexedDB` for more details. */ var BrowserIndexedDB = /** @class */ (function () { function BrowserIndexedDB(modelPath) { this.indexedDB = getIndexedDBFactory(); if (modelPath == null || !modelPath) { throw new Error('For IndexedDB, modelPath must not be null, undefined or empty.'); } this.modelPath = modelPath; } BrowserIndexedDB.prototype.save = function (modelArtifacts) { return __awaiter(this, void 0, void 0, function () { return __generator(this, function (_a) { // TODO(cais): Support saving GraphDef models. 
if (modelArtifacts.modelTopology instanceof ArrayBuffer) { throw new Error('BrowserLocalStorage.save() does not support saving model topology ' + 'in binary formats yet.'); } return [2 /*return*/, this.databaseAction(this.modelPath, modelArtifacts)]; }); }); }; BrowserIndexedDB.prototype.load = function () { return __awaiter(this, void 0, void 0, function () { return __generator(this, function (_a) { return [2 /*return*/, this.databaseAction(this.modelPath)]; }); }); }; /** * Perform database action to put model artifacts into or read model artifacts * from IndexedDB object store. * * Whether the action is put or get depends on whether `modelArtifacts` is * specified. If it is specified, the action will be put; otherwise the action * will be get. * * @param modelPath A unique string path for the model. * @param modelArtifacts If specified, it will be the model artifacts to be * stored in IndexedDB. * @returns A `Promise` of `SaveResult`, if the action is put, or a `Promise` * of `ModelArtifacts`, if the action is get. */ BrowserIndexedDB.prototype.databaseAction = function (modelPath, modelArtifacts) { var _this = this; return new Promise(function (resolve, reject) { var openRequest = _this.indexedDB.open(DATABASE_NAME, DATABASE_VERSION); openRequest.onupgradeneeded = function () { return setUpDatabase(openRequest); }; openRequest.onsuccess = function () { var db = openRequest.result; if (modelArtifacts == null) { // Read model out from object store. 
var modelTx = db.transaction(MODEL_STORE_NAME, 'readonly'); var modelStore = modelTx.objectStore(MODEL_STORE_NAME); var getRequest_1 = modelStore.get(_this.modelPath); getRequest_1.onsuccess = function () { if (getRequest_1.result == null) { db.close(); return reject(new Error("Cannot find model with path '" + _this.modelPath + "' " + "in IndexedDB.")); } else { resolve(getRequest_1.result.modelArtifacts); } }; getRequest_1.onerror = function (error) { db.close(); return reject(getRequest_1.error); }; modelTx.oncomplete = function () { return db.close(); }; } else { // Put model into object store. var modelArtifactsInfo_1 = getModelArtifactsInfoForJSON(modelArtifacts); // First, put ModelArtifactsInfo into info store. var infoTx_1 = db.transaction(INFO_STORE_NAME, 'readwrite'); var infoStore_1 = infoTx_1.objectStore(INFO_STORE_NAME); var putInfoRequest_1 = infoStore_1.put({ modelPath: _this.modelPath, modelArtifactsInfo: modelArtifactsInfo_1 }); var modelTx_1; putInfoRequest_1.onsuccess = function () { // Second, put model data into model store. modelTx_1 = db.transaction(MODEL_STORE_NAME, 'readwrite'); var modelStore = modelTx_1.objectStore(MODEL_STORE_NAME); var putModelRequest = modelStore.put({ modelPath: _this.modelPath, modelArtifacts: modelArtifacts, modelArtifactsInfo: modelArtifactsInfo_1 }); putModelRequest.onsuccess = function () { return resolve({ modelArtifactsInfo: modelArtifactsInfo_1 }); }; putModelRequest.onerror = function (error) { // If the put-model request fails, roll back the info entry as // well. 
infoStore_1 = infoTx_1.objectStore(INFO_STORE_NAME); var deleteInfoRequest = infoStore_1.delete(_this.modelPath); deleteInfoRequest.onsuccess = function () { db.close(); return reject(putModelRequest.error); }; deleteInfoRequest.onerror = function (error) { db.close(); return reject(putModelRequest.error); }; }; }; putInfoRequest_1.onerror = function (error) { db.close(); return reject(putInfoRequest_1.error); }; infoTx_1.oncomplete = function () { if (modelTx_1 == null) { db.close(); } else { modelTx_1.oncomplete = function () { return db.close(); }; } }; } }; openRequest.onerror = function (error) { return reject(openRequest.error); }; }); }; BrowserIndexedDB.URL_SCHEME = 'indexeddb://'; return BrowserIndexedDB; }()); var indexedDBRouter = function (url) { if (!env().getBool('IS_BROWSER')) { return null; } else { if (!Array.isArray(url) && url.startsWith(BrowserIndexedDB.URL_SCHEME)) { return browserIndexedDB(url.slice(BrowserIndexedDB.URL_SCHEME.length)); } else { return null; } } }; IORouterRegistry.registerSaveRouter(indexedDBRouter); IORouterRegistry.registerLoadRouter(indexedDBRouter); /** * Creates a browser IndexedDB IOHandler for saving and loading models. * * ```js * const model = tf.sequential(); * model.add( * tf.layers.dense({units: 1, inputShape: [100], activation: 'sigmoid'})); * * const saveResult = await model.save('indexeddb://MyModel')); * console.log(saveResult); * ``` * * @param modelPath A unique identifier for the model to be saved. Must be a * non-empty string. * @returns An instance of `BrowserIndexedDB` (sublcass of `IOHandler`), * which can be used with, e.g., `tf.Model.save`. */ function browserIndexedDB(modelPath) { return new BrowserIndexedDB(modelPath); } function maybeStripScheme(key) { return key.startsWith(BrowserIndexedDB.URL_SCHEME) ? 
key.slice(BrowserIndexedDB.URL_SCHEME.length) : key; } var BrowserIndexedDBManager = /** @class */ (function () { function BrowserIndexedDBManager() { this.indexedDB = getIndexedDBFactory(); } BrowserIndexedDBManager.prototype.listModels = function () { return __awaiter(this, void 0, void 0, function () { var _this = this; return __generator(this, function (_a) { return [2 /*return*/, new Promise(function (resolve, reject) { var openRequest = _this.indexedDB.open(DATABASE_NAME, DATABASE_VERSION); openRequest.onupgradeneeded = function () { return setUpDatabase(openRequest); }; openRequest.onsuccess = function () { var db = openRequest.result; var tx = db.transaction(INFO_STORE_NAME, 'readonly'); var store = tx.objectStore(INFO_STORE_NAME); // tslint:disable:max-line-length // Need to cast `store` as `any` here because TypeScript's DOM // library does not have the `getAll()` method even though the // method is supported in the latest version of most mainstream // browsers: // https://developer.mozilla.org/en-US/docs/Web/API/IDBObjectStore/getAll // tslint:enable:max-line-length // tslint:disable-next-line:no-any var getAllInfoRequest = store.getAll(); getAllInfoRequest.onsuccess = function () { var out = {}; for (var _i = 0, _a = getAllInfoRequest.result; _i < _a.length; _i++) { var item = _a[_i]; out[item.modelPath] = item.modelArtifactsInfo; } resolve(out); }; getAllInfoRequest.onerror = function (error) { db.close(); return reject(getAllInfoRequest.error); }; tx.oncomplete = function () { return db.close(); }; }; openRequest.onerror = function (error) { return reject(openRequest.error); }; })]; }); }); }; BrowserIndexedDBManager.prototype.removeModel = function (path) { return __awaiter(this, void 0, void 0, function () { var _this = this; return __generator(this, function (_a) { path = maybeStripScheme(path); return [2 /*return*/, new Promise(function (resolve, reject) { var openRequest = _this.indexedDB.open(DATABASE_NAME, DATABASE_VERSION); 
openRequest.onupgradeneeded = function () { return setUpDatabase(openRequest); }; openRequest.onsuccess = function () { var db = openRequest.result; var infoTx = db.transaction(INFO_STORE_NAME, 'readwrite'); var infoStore = infoTx.objectStore(INFO_STORE_NAME); var getInfoRequest = infoStore.get(path); var modelTx; getInfoRequest.onsuccess = function () { if (getInfoRequest.result == null) { db.close(); return reject(new Error("Cannot find model with path '" + path + "' " + "in IndexedDB.")); } else { // First, delete the entry in the info store. var deleteInfoRequest = infoStore.delete(path); var deleteModelData_1 = function () { // Second, delete the entry in the model store. modelTx = db.transaction(MODEL_STORE_NAME, 'readwrite'); var modelStore = modelTx.objectStore(MODEL_STORE_NAME); var deleteModelRequest = modelStore.delete(path); deleteModelRequest.onsuccess = function () { return resolve(getInfoRequest.result.modelArtifactsInfo); }; deleteModelRequest.onerror = function (error) { return reject(getInfoRequest.error); }; }; // Proceed with deleting model data regardless of whether deletion // of info data succeeds or not. deleteInfoRequest.onsuccess = deleteModelData_1; deleteInfoRequest.onerror = function (error) { deleteModelData_1(); db.close(); return reject(getInfoRequest.error); }; } }; getInfoRequest.onerror = function (error) { db.close(); return reject(getInfoRequest.error); }; infoTx.oncomplete = function () { if (modelTx == null) { db.close(); } else { modelTx.oncomplete = function () { return db.close(); }; } }; }; openRequest.onerror = function (error) { return reject(openRequest.error); }; })]; }); }); }; return BrowserIndexedDBManager; }()); /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var PATH_SEPARATOR = '/'; var PATH_PREFIX = 'tensorflowjs_models'; var INFO_SUFFIX = 'info'; var MODEL_TOPOLOGY_SUFFIX = 'model_topology'; var WEIGHT_SPECS_SUFFIX = 'weight_specs'; var WEIGHT_DATA_SUFFIX = 'weight_data'; var MODEL_METADATA_SUFFIX = 'model_metadata'; function getModelKeys(path) { return { info: [PATH_PREFIX, path, INFO_SUFFIX].join(PATH_SEPARATOR), topology: [PATH_PREFIX, path, MODEL_TOPOLOGY_SUFFIX].join(PATH_SEPARATOR), weightSpecs: [PATH_PREFIX, path, WEIGHT_SPECS_SUFFIX].join(PATH_SEPARATOR), weightData: [PATH_PREFIX, path, WEIGHT_DATA_SUFFIX].join(PATH_SEPARATOR), modelMetadata: [PATH_PREFIX, path, MODEL_METADATA_SUFFIX].join(PATH_SEPARATOR) }; } /** * Get model path from a local-storage key. * * E.g., 'tensorflowjs_models/my/model/1/info' --> 'my/model/1' * * @param key */ function getModelPathFromKey(key) { var items = key.split(PATH_SEPARATOR); if (items.length < 3) { throw new Error("Invalid key format: " + key); } return items.slice(1, items.length - 1).join(PATH_SEPARATOR); } function maybeStripScheme$1(key) { return key.startsWith(BrowserLocalStorage.URL_SCHEME) ? key.slice(BrowserLocalStorage.URL_SCHEME.length) : key; } /** * IOHandler subclass: Browser Local Storage. * * See the doc string to `browserLocalStorage` for more details. 
*/ var BrowserLocalStorage = /** @class */ (function () { function BrowserLocalStorage(modelPath) { if (!env().getBool('IS_BROWSER') || typeof window === 'undefined' || typeof window.localStorage === 'undefined') { // TODO(cais): Add more info about what IOHandler subtypes are // available. // Maybe point to a doc page on the web and/or automatically determine // the available IOHandlers and print them in the error message. throw new Error('The current environment does not support local storage.'); } this.LS = window.localStorage; if (modelPath == null || !modelPath) { throw new Error('For local storage, modelPath must not be null, undefined or empty.'); } this.modelPath = modelPath; this.keys = getModelKeys(this.modelPath); } /** * Save model artifacts to browser local storage. * * See the documentation to `browserLocalStorage` for details on the saved * artifacts. * * @param modelArtifacts The model artifacts to be stored. * @returns An instance of SaveResult. */ BrowserLocalStorage.prototype.save = function (modelArtifacts) { return __awaiter(this, void 0, void 0, function () { var topology, weightSpecs, modelArtifactsInfo, result; return __generator(this, function (_a) { if (modelArtifacts.modelTopology instanceof ArrayBuffer) { throw new Error('BrowserLocalStorage.save() does not support saving model topology ' + 'in binary formats yet.'); } else { topology = JSON.stringify(modelArtifacts.modelTopology); weightSpecs = JSON.stringify(modelArtifacts.weightSpecs); modelArtifactsInfo = getModelArtifactsInfoForJSON(modelArtifacts); try { this.LS.setItem(this.keys.info, JSON.stringify(modelArtifactsInfo)); this.LS.setItem(this.keys.topology, topology); this.LS.setItem(this.keys.weightSpecs, weightSpecs); this.LS.setItem(this.keys.weightData, arrayBufferToBase64String(modelArtifacts.weightData)); result = { format: modelArtifacts.format, generatedBy: modelArtifacts.generatedBy, convertedBy: modelArtifacts.convertedBy }; if (modelArtifacts.signature != null) { 
result.signature = modelArtifacts.signature; } if (modelArtifacts.userDefinedMetadata != null) { result.userDefinedMetadata = modelArtifacts.userDefinedMetadata; } if (modelArtifacts.modelInitializer != null) { result.modelInitializer = modelArtifacts.modelInitializer; } this.LS.setItem(this.keys.modelMetadata, JSON.stringify(result)); return [2 /*return*/, { modelArtifactsInfo: modelArtifactsInfo }]; } catch (err) { // If saving failed, clean up all items saved so far. this.LS.removeItem(this.keys.info); this.LS.removeItem(this.keys.topology); this.LS.removeItem(this.keys.weightSpecs); this.LS.removeItem(this.keys.weightData); this.LS.removeItem(this.keys.modelMetadata); throw new Error("Failed to save model '" + this.modelPath + "' to local storage: " + "size quota being exceeded is a possible cause of this failure: " + ("modelTopologyBytes=" + modelArtifactsInfo.modelTopologyBytes + ", ") + ("weightSpecsBytes=" + modelArtifactsInfo.weightSpecsBytes + ", ") + ("weightDataBytes=" + modelArtifactsInfo.weightDataBytes + ".")); } } return [2 /*return*/]; }); }); }; /** * Load a model from local storage. * * See the documentation to `browserLocalStorage` for details on the saved * artifacts. * * @returns The loaded model (if loading succeeds). 
*/ BrowserLocalStorage.prototype.load = function () { return __awaiter(this, void 0, void 0, function () { var info, out, topology, weightSpecs, metadataString, metadata, weightDataBase64; return __generator(this, function (_a) { info = JSON.parse(this.LS.getItem(this.keys.info)); if (info == null) { throw new Error("In local storage, there is no model with name '" + this.modelPath + "'"); } if (info.modelTopologyType !== 'JSON') { throw new Error('BrowserLocalStorage does not support loading non-JSON model ' + 'topology yet.'); } out = {}; topology = JSON.parse(this.LS.getItem(this.keys.topology)); if (topology == null) { throw new Error("In local storage, the topology of model '" + this.modelPath + "' " + "is missing."); } out.modelTopology = topology; weightSpecs = JSON.parse(this.LS.getItem(this.keys.weightSpecs)); if (weightSpecs == null) { throw new Error("In local storage, the weight specs of model '" + this.modelPath + "' " + "are missing."); } out.weightSpecs = weightSpecs; metadataString = this.LS.getItem(this.keys.modelMetadata); if (metadataString != null) { metadata = JSON.parse(metadataString); out.format = metadata['format']; out.generatedBy = metadata['generatedBy']; out.convertedBy = metadata['convertedBy']; if (metadata['signature'] != null) { out.signature = metadata['signature']; } if (metadata['userDefinedMetadata'] != null) { out.userDefinedMetadata = metadata['userDefinedMetadata']; } if (metadata['modelInitializer'] != null) { out.modelInitializer = metadata['modelInitializer']; } } weightDataBase64 = this.LS.getItem(this.keys.weightData); if (weightDataBase64 == null) { throw new Error("In local storage, the binary weight values of model " + ("'" + this.modelPath + "' are missing.")); } out.weightData = base64StringToArrayBuffer(weightDataBase64); return [2 /*return*/, out]; }); }); }; BrowserLocalStorage.URL_SCHEME = 'localstorage://'; return BrowserLocalStorage; }()); var localStorageRouter = function (url) { if 
(!env().getBool('IS_BROWSER')) { return null; } else { if (!Array.isArray(url) && url.startsWith(BrowserLocalStorage.URL_SCHEME)) { return browserLocalStorage(url.slice(BrowserLocalStorage.URL_SCHEME.length)); } else { return null; } } }; IORouterRegistry.registerSaveRouter(localStorageRouter); IORouterRegistry.registerLoadRouter(localStorageRouter); /** * Factory function for local storage IOHandler. * * This `IOHandler` supports both `save` and `load`. * * For each model's saved artifacts, four items are saved to local storage. * - `${PATH_SEPARATOR}/${modelPath}/info`: Contains meta-info about the * model, such as date saved, type of the topology, size in bytes, etc. * - `${PATH_SEPARATOR}/${modelPath}/topology`: Model topology. For Keras- * style models, this is a stringized JSON. * - `${PATH_SEPARATOR}/${modelPath}/weight_specs`: Weight specs of the * model, can be used to decode the saved binary weight values (see * item below). * - `${PATH_SEPARATOR}/${modelPath}/weight_data`: Concatenated binary * weight values, stored as a base64-encoded string. * * Saving may throw an `Error` if the total size of the artifacts exceed the * browser-specific quota. * * @param modelPath A unique identifier for the model to be saved. Must be a * non-empty string. * @returns An instance of `IOHandler`, which can be used with, e.g., * `tf.Model.save`. 
*/ function browserLocalStorage(modelPath) { return new BrowserLocalStorage(modelPath); } var BrowserLocalStorageManager = /** @class */ (function () { function BrowserLocalStorageManager() { assert(env().getBool('IS_BROWSER'), function () { return 'Current environment is not a web browser'; }); assert(typeof window === 'undefined' || typeof window.localStorage !== 'undefined', function () { return 'Current browser does not appear to support localStorage'; }); this.LS = window.localStorage; } BrowserLocalStorageManager.prototype.listModels = function () { return __awaiter(this, void 0, void 0, function () { var out, prefix, suffix, i, key, modelPath; return __generator(this, function (_a) { out = {}; prefix = PATH_PREFIX + PATH_SEPARATOR; suffix = PATH_SEPARATOR + INFO_SUFFIX; for (i = 0; i < this.LS.length; ++i) { key = this.LS.key(i); if (key.startsWith(prefix) && key.endsWith(suffix)) { modelPath = getModelPathFromKey(key); out[modelPath] = JSON.parse(this.LS.getItem(key)); } } return [2 /*return*/, out]; }); }); }; BrowserLocalStorageManager.prototype.removeModel = function (path) { return __awaiter(this, void 0, void 0, function () { var keys, info; return __generator(this, function (_a) { path = maybeStripScheme$1(path); keys = getModelKeys(path); if (this.LS.getItem(keys.info) == null) { throw new Error("Cannot find model at path '" + path + "'"); } info = JSON.parse(this.LS.getItem(keys.info)); this.LS.removeItem(keys.info); this.LS.removeItem(keys.topology); this.LS.removeItem(keys.weightSpecs); this.LS.removeItem(keys.weightData); return [2 /*return*/, info]; }); }); }; return BrowserLocalStorageManager; }()); /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var URL_SCHEME_SUFFIX = '://'; var ModelStoreManagerRegistry = /** @class */ (function () { function ModelStoreManagerRegistry() { this.managers = {}; } ModelStoreManagerRegistry.getInstance = function () { if (ModelStoreManagerRegistry.instance == null) { ModelStoreManagerRegistry.instance = new ModelStoreManagerRegistry(); } return ModelStoreManagerRegistry.instance; }; /** * Register a save-handler router. * * @param saveRouter A function that maps a URL-like string onto an instance * of `IOHandler` with the `save` method defined or `null`. 
*/ ModelStoreManagerRegistry.registerManager = function (scheme, manager) { assert(scheme != null, function () { return 'scheme must not be undefined or null.'; }); if (scheme.endsWith(URL_SCHEME_SUFFIX)) { scheme = scheme.slice(0, scheme.indexOf(URL_SCHEME_SUFFIX)); } assert(scheme.length > 0, function () { return 'scheme must not be an empty string.'; }); var registry = ModelStoreManagerRegistry.getInstance(); assert(registry.managers[scheme] == null, function () { return "A model store manager is already registered for scheme '" + scheme + "'."; }); registry.managers[scheme] = manager; }; ModelStoreManagerRegistry.getManager = function (scheme) { var manager = this.getInstance().managers[scheme]; if (manager == null) { throw new Error("Cannot find model manager for scheme '" + scheme + "'"); } return manager; }; ModelStoreManagerRegistry.getSchemes = function () { return Object.keys(this.getInstance().managers); }; return ModelStoreManagerRegistry; }()); /** * Helper method for parsing a URL string into a scheme and a path. * * @param url E.g., 'localstorage://my-model' * @returns A dictionary with two fields: scheme and path. * Scheme: e.g., 'localstorage' in the example above. * Path: e.g., 'my-model' in the example above. */ function parseURL(url) { if (url.indexOf(URL_SCHEME_SUFFIX) === -1) { throw new Error("The url string provided does not contain a scheme. 
" + "Supported schemes are: " + ("" + ModelStoreManagerRegistry.getSchemes().join(','))); } return { scheme: url.split(URL_SCHEME_SUFFIX)[0], path: url.split(URL_SCHEME_SUFFIX)[1], }; } function cloneModelInternal(sourceURL, destURL, deleteSource) { if (deleteSource === void 0) { deleteSource = false; } return __awaiter(this, void 0, void 0, function () { var loadHandlers, loadHandler, saveHandlers, saveHandler, sourceScheme, sourcePath, sameMedium, modelArtifacts, saveResult; return __generator(this, function (_a) { switch (_a.label) { case 0: assert(sourceURL !== destURL, function () { return "Old path and new path are the same: '" + sourceURL + "'"; }); loadHandlers = IORouterRegistry.getLoadHandlers(sourceURL); assert(loadHandlers.length > 0, function () { return "Copying failed because no load handler is found for source URL " + sourceURL + "."; }); assert(loadHandlers.length < 2, function () { return "Copying failed because more than one (" + loadHandlers.length + ") " + ("load handlers for source URL " + sourceURL + "."); }); loadHandler = loadHandlers[0]; saveHandlers = IORouterRegistry.getSaveHandlers(destURL); assert(saveHandlers.length > 0, function () { return "Copying failed because no save handler is found for destination " + ("URL " + destURL + "."); }); assert(saveHandlers.length < 2, function () { return "Copying failed because more than one (" + loadHandlers.length + ") " + ("save handlers for destination URL " + destURL + "."); }); saveHandler = saveHandlers[0]; sourceScheme = parseURL(sourceURL).scheme; sourcePath = parseURL(sourceURL).path; sameMedium = sourceScheme === parseURL(sourceURL).scheme; return [4 /*yield*/, loadHandler.load()]; case 1: modelArtifacts = _a.sent(); if (!(deleteSource && sameMedium)) return [3 /*break*/, 3]; return [4 /*yield*/, ModelStoreManagerRegistry.getManager(sourceScheme) .removeModel(sourcePath)]; case 2: _a.sent(); _a.label = 3; case 3: return [4 /*yield*/, saveHandler.save(modelArtifacts)]; case 4: saveResult 
= _a.sent(); if (!(deleteSource && !sameMedium)) return [3 /*break*/, 6]; return [4 /*yield*/, ModelStoreManagerRegistry.getManager(sourceScheme) .removeModel(sourcePath)]; case 5: _a.sent(); _a.label = 6; case 6: return [2 /*return*/, saveResult.modelArtifactsInfo]; } }); }); } /** * List all models stored in registered storage mediums. * * For a web browser environment, the registered mediums are Local Storage and * IndexedDB. * * ```js * // First create and save a model. * const model = tf.sequential(); * model.add(tf.layers.dense( * {units: 1, inputShape: [10], activation: 'sigmoid'})); * await model.save('localstorage://demo/management/model1'); * * // Then list existing models. * console.log(JSON.stringify(await tf.io.listModels())); * * // Delete the model. * await tf.io.removeModel('localstorage://demo/management/model1'); * * // List models again. * console.log(JSON.stringify(await tf.io.listModels())); * ``` * * @returns A `Promise` of a dictionary mapping URLs of existing models to * their model artifacts info. URLs include medium-specific schemes, e.g., * 'indexeddb://my/model/1'. Model artifacts info include type of the * model's topology, byte sizes of the topology, weights, etc. 
* * @doc { * heading: 'Models', * subheading: 'Management', * namespace: 'io', * ignoreCI: true * } */ function listModels() { return __awaiter(this, void 0, void 0, function () { var schemes, out, _i, schemes_1, scheme, schemeOut, path, url; return __generator(this, function (_a) { switch (_a.label) { case 0: schemes = ModelStoreManagerRegistry.getSchemes(); out = {}; _i = 0, schemes_1 = schemes; _a.label = 1; case 1: if (!(_i < schemes_1.length)) return [3 /*break*/, 4]; scheme = schemes_1[_i]; return [4 /*yield*/, ModelStoreManagerRegistry.getManager(scheme).listModels()]; case 2: schemeOut = _a.sent(); for (path in schemeOut) { url = scheme + URL_SCHEME_SUFFIX + path; out[url] = schemeOut[path]; } _a.label = 3; case 3: _i++; return [3 /*break*/, 1]; case 4: return [2 /*return*/, out]; } }); }); } /** * Remove a model specified by URL from a reigstered storage medium. * * ```js * // First create and save a model. * const model = tf.sequential(); * model.add(tf.layers.dense( * {units: 1, inputShape: [10], activation: 'sigmoid'})); * await model.save('localstorage://demo/management/model1'); * * // Then list existing models. * console.log(JSON.stringify(await tf.io.listModels())); * * // Delete the model. * await tf.io.removeModel('localstorage://demo/management/model1'); * * // List models again. * console.log(JSON.stringify(await tf.io.listModels())); * ``` * * @param url A URL to a stored model, with a scheme prefix, e.g., * 'localstorage://my-model-1', 'indexeddb://my/model/2'. * @returns ModelArtifactsInfo of the deleted model (if and only if deletion * is successful). * @throws Error if deletion fails, e.g., if no model exists at `path`. 
* * @doc { * heading: 'Models', * subheading: 'Management', * namespace: 'io', * ignoreCI: true * } */ function removeModel(url) { return __awaiter(this, void 0, void 0, function () { var schemeAndPath, manager; return __generator(this, function (_a) { schemeAndPath = parseURL(url); manager = ModelStoreManagerRegistry.getManager(schemeAndPath.scheme); return [2 /*return*/, manager.removeModel(schemeAndPath.path)]; }); }); } /** * Copy a model from one URL to another. * * This function supports: * * 1. Copying within a storage medium, e.g., * `tf.io.copyModel('localstorage://model-1', 'localstorage://model-2')` * 2. Copying between two storage mediums, e.g., * `tf.io.copyModel('localstorage://model-1', 'indexeddb://model-1')` * * ```js * // First create and save a model. * const model = tf.sequential(); * model.add(tf.layers.dense( * {units: 1, inputShape: [10], activation: 'sigmoid'})); * await model.save('localstorage://demo/management/model1'); * * // Then list existing models. * console.log(JSON.stringify(await tf.io.listModels())); * * // Copy the model, from Local Storage to IndexedDB. * await tf.io.copyModel( * 'localstorage://demo/management/model1', * 'indexeddb://demo/management/model1'); * * // List models again. * console.log(JSON.stringify(await tf.io.listModels())); * * // Remove both models. * await tf.io.removeModel('localstorage://demo/management/model1'); * await tf.io.removeModel('indexeddb://demo/management/model1'); * ``` * * @param sourceURL Source URL of copying. * @param destURL Destination URL of copying. * @returns ModelArtifactsInfo of the copied model (if and only if copying * is successful). * @throws Error if copying fails, e.g., if no model exists at `sourceURL`, or * if `oldPath` and `newPath` are identical. 
* * @doc { * heading: 'Models', * subheading: 'Management', * namespace: 'io', * ignoreCI: true * } */ function copyModel(sourceURL, destURL) { return __awaiter(this, void 0, void 0, function () { var deleteSource; return __generator(this, function (_a) { deleteSource = false; return [2 /*return*/, cloneModelInternal(sourceURL, destURL, deleteSource)]; }); }); } /** * Move a model from one URL to another. * * This function supports: * * 1. Moving within a storage medium, e.g., * `tf.io.moveModel('localstorage://model-1', 'localstorage://model-2')` * 2. Moving between two storage mediums, e.g., * `tf.io.moveModel('localstorage://model-1', 'indexeddb://model-1')` * * ```js * // First create and save a model. * const model = tf.sequential(); * model.add(tf.layers.dense( * {units: 1, inputShape: [10], activation: 'sigmoid'})); * await model.save('localstorage://demo/management/model1'); * * // Then list existing models. * console.log(JSON.stringify(await tf.io.listModels())); * * // Move the model, from Local Storage to IndexedDB. * await tf.io.moveModel( * 'localstorage://demo/management/model1', * 'indexeddb://demo/management/model1'); * * // List models again. * console.log(JSON.stringify(await tf.io.listModels())); * * // Remove the moved model. * await tf.io.removeModel('indexeddb://demo/management/model1'); * ``` * * @param sourceURL Source URL of moving. * @param destURL Destination URL of moving. * @returns ModelArtifactsInfo of the copied model (if and only if copying * is successful). * @throws Error if moving fails, e.g., if no model exists at `sourceURL`, or * if `oldPath` and `newPath` are identical. 
* * @doc { * heading: 'Models', * subheading: 'Management', * namespace: 'io', * ignoreCI: true * } */ function moveModel(sourceURL, destURL) { return __awaiter(this, void 0, void 0, function () { var deleteSource; return __generator(this, function (_a) { deleteSource = true; return [2 /*return*/, cloneModelInternal(sourceURL, destURL, deleteSource)]; }); }); } /** * @license * Copyright 2019 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var PlatformBrowser = /** @class */ (function () { function PlatformBrowser() { } PlatformBrowser.prototype.fetch = function (path, init) { return fetch(path, init); }; PlatformBrowser.prototype.now = function () { return performance.now(); }; PlatformBrowser.prototype.encode = function (text, encoding) { if (encoding !== 'utf-8' && encoding !== 'utf8') { throw new Error("Browser's encoder only supports utf-8, but got " + encoding); } if (this.textEncoder == null) { this.textEncoder = new TextEncoder(); } return this.textEncoder.encode(text); }; PlatformBrowser.prototype.decode = function (bytes, encoding) { return new TextDecoder(encoding).decode(bytes); }; return PlatformBrowser; }()); if (env().get('IS_BROWSER')) { env().setPlatform('browser', new PlatformBrowser()); // Register LocalStorage IOHandler try { ModelStoreManagerRegistry.registerManager(BrowserLocalStorage.URL_SCHEME, new BrowserLocalStorageManager()); } catch (err) { 
} // Register IndexedDB IOHandler try { ModelStoreManagerRegistry.registerManager(BrowserIndexedDB.URL_SCHEME, new BrowserIndexedDBManager()); } catch (err) { } } /** * @license * Copyright 2019 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ // We are wrapping this within an object so it can be stubbed by Jasmine. var getNodeFetch = { // tslint:disable-next-line:no-require-imports importFetch: function () { return require('node-fetch'); } }; var systemFetch; var PlatformNode = /** @class */ (function () { function PlatformNode() { // tslint:disable-next-line:no-require-imports this.util = require('util'); // According to the spec, the built-in encoder can do only UTF-8 encoding. 
// https://developer.mozilla.org/en-US/docs/Web/API/TextEncoder/TextEncoder this.textEncoder = new this.util.TextEncoder(); } PlatformNode.prototype.fetch = function (path, requestInits) { if (env().global.fetch != null) { return env().global.fetch(path, requestInits); } if (systemFetch == null) { systemFetch = getNodeFetch.importFetch(); } return systemFetch(path, requestInits); }; PlatformNode.prototype.now = function () { var time = process.hrtime(); return time[0] * 1000 + time[1] / 1000000; }; PlatformNode.prototype.encode = function (text, encoding) { if (encoding !== 'utf-8' && encoding !== 'utf8') { throw new Error("Node built-in encoder only supports utf-8, but got " + encoding); } return this.textEncoder.encode(text); }; PlatformNode.prototype.decode = function (bytes, encoding) { if (bytes.length === 0) { return ''; } return new this.util.TextDecoder(encoding).decode(bytes); }; return PlatformNode; }()); if (env().get('IS_NODE')) { env().setPlatform('node', new PlatformNode()); } /** * @license * Copyright 2020 Google Inc. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /** * Creates an empty `tf.TensorBuffer` with the specified `shape` and `dtype`. * * The values are stored in CPU as `TypedArray`. Fill the buffer using * `buffer.set()`, or by modifying directly `buffer.values`. * * When done, call `buffer.toTensor()` to get an immutable `tf.Tensor` with * those values. 
* * ```js * // Create a buffer and set values at particular indices. * const buffer = tf.buffer([2, 2]); * buffer.set(3, 0, 0); * buffer.set(5, 1, 0); * * // Convert the buffer back to a tensor. * buffer.toTensor().print(); * ``` * * @param shape An array of integers defining the output tensor shape. * @param dtype The dtype of the buffer. Defaults to 'float32'. * @param values The values of the buffer as `TypedArray`. Defaults to * zeros. * * @doc {heading: 'Tensors', subheading: 'Creation'} */ function buffer(shape, dtype, values) { if (dtype === void 0) { dtype = 'float32'; } dtype = dtype || 'float32'; assertNonNegativeIntegerDimensions(shape); return new TensorBuffer(shape, dtype, values); } /** * @license * Copyright 2020 Google Inc. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /** * Casts a `tf.Tensor` to a new dtype. * * ```js * const x = tf.tensor1d([1.5, 2.5, 3]); * tf.cast(x, 'int32').print(); * ``` * @param x The input tensor to be casted. * @param dtype The dtype to cast the input tensor to. * * @doc {heading: 'Tensors', subheading: 'Transformations'} */ function cast_(x, dtype) { var $x = convertToTensor(x, 'x', 'cast'); // Sanity checks. 
if (!isValidDtype(dtype)) { throw new Error("Failed to cast to unknown dtype " + dtype); } if (dtype === 'string' && $x.dtype !== 'string' || dtype !== 'string' && $x.dtype === 'string') { throw new Error('Only strings can be casted to strings'); } var inputs = { x: $x }; var attrs = { dtype: dtype }; return ENGINE.runKernel(Cast, inputs, attrs); } var cast = op({ cast_: cast_ }); /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /** * Creates a new tensor with the same values and shape as the specified * tensor. * * ```js * const x = tf.tensor([1, 2]); * * x.clone().print(); * ``` * * @param x The tensor to clone. * * @doc {heading: 'Tensors', subheading: 'Creation'} */ function clone_(x) { var $x = convertToTensor(x, 'x', 'clone', 'string_or_numeric'); var inputs = { x: $x }; // Note this op is called tf.identity in python. Hence the kernel name used // here. return ENGINE.runKernel(Identity, inputs); } var clone = op({ clone_: clone_ }); /** * @license * Copyright 2020 Google Inc. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /** * Prints information about the `tf.Tensor` including its data. * * ```js * const verbose = true; * tf.tensor2d([1, 2, 3, 4], [2, 2]).print(verbose); * ``` * @param x The tensor to be printed. * @param verbose Whether to print verbose information about the ` Tensor`, * including dtype and size. * * @doc {heading: 'Tensors', subheading: 'Creation'} */ function print(x, verbose) { if (verbose === void 0) { verbose = false; } console.log(x.toString(verbose)); } /** * @license * Copyright 2020 Google Inc. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ getOrMakeEngine(); var opHandler$1 = { buffer: buffer, cast: cast, clone: clone, print: print }; setOpHandler(opHandler$1); /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var DEFAULT_FILE_NAME_PREFIX = 'model'; var DEFAULT_JSON_EXTENSION_NAME = '.json'; var DEFAULT_WEIGHT_DATA_EXTENSION_NAME = '.weights.bin'; function defer(f) { return new Promise(function (resolve) { return setTimeout(resolve); }).then(f); } var BrowserDownloads = /** @class */ (function () { function BrowserDownloads(fileNamePrefix) { if (!env().getBool('IS_BROWSER')) { // TODO(cais): Provide info on what IOHandlers are available under the // current environment. throw new Error('browserDownloads() cannot proceed because the current environment ' + 'is not a browser.'); } if (fileNamePrefix.startsWith(BrowserDownloads.URL_SCHEME)) { fileNamePrefix = fileNamePrefix.slice(BrowserDownloads.URL_SCHEME.length); } if (fileNamePrefix == null || fileNamePrefix.length === 0) { fileNamePrefix = DEFAULT_FILE_NAME_PREFIX; } this.modelTopologyFileName = fileNamePrefix + DEFAULT_JSON_EXTENSION_NAME; this.weightDataFileName = fileNamePrefix + DEFAULT_WEIGHT_DATA_EXTENSION_NAME; } BrowserDownloads.prototype.save = function (modelArtifacts) { return __awaiter(this, void 0, void 0, function () { var weightsURL, weightsManifest, modelTopologyAndWeightManifest, modelTopologyAndWeightManifestURL, jsonAnchor_1, weightDataAnchor_1; return __generator(this, function (_a) { switch (_a.label) { case 0: if (typeof (document) === 'undefined') { throw new Error('Browser downloads are not supported in ' + 'this environment since `document` is not present'); } weightsURL = window.URL.createObjectURL(new 
Blob([modelArtifacts.weightData], { type: 'application/octet-stream' })); if (!(modelArtifacts.modelTopology instanceof ArrayBuffer)) return [3 /*break*/, 1]; throw new Error('BrowserDownloads.save() does not support saving model topology ' + 'in binary formats yet.'); case 1: weightsManifest = [{ paths: ['./' + this.weightDataFileName], weights: modelArtifacts.weightSpecs }]; modelTopologyAndWeightManifest = { modelTopology: modelArtifacts.modelTopology, format: modelArtifacts.format, generatedBy: modelArtifacts.generatedBy, convertedBy: modelArtifacts.convertedBy, weightsManifest: weightsManifest }; if (modelArtifacts.signature != null) { modelTopologyAndWeightManifest.signature = modelArtifacts.signature; } if (modelArtifacts.userDefinedMetadata != null) { modelTopologyAndWeightManifest.userDefinedMetadata = modelArtifacts.userDefinedMetadata; } if (modelArtifacts.modelInitializer != null) { modelTopologyAndWeightManifest.modelInitializer = modelArtifacts.modelInitializer; } modelTopologyAndWeightManifestURL = window.URL.createObjectURL(new Blob([JSON.stringify(modelTopologyAndWeightManifest)], { type: 'application/json' })); jsonAnchor_1 = this.jsonAnchor == null ? document.createElement('a') : this.jsonAnchor; jsonAnchor_1.download = this.modelTopologyFileName; jsonAnchor_1.href = modelTopologyAndWeightManifestURL; // Trigger downloads by evoking a click event on the download anchors. // When multiple downloads are started synchronously, Firefox will only // save the last one. return [4 /*yield*/, defer(function () { return jsonAnchor_1.dispatchEvent(new MouseEvent('click')); })]; case 2: // Trigger downloads by evoking a click event on the download anchors. // When multiple downloads are started synchronously, Firefox will only // save the last one. _a.sent(); if (!(modelArtifacts.weightData != null)) return [3 /*break*/, 4]; weightDataAnchor_1 = this.weightDataAnchor == null ? 
document.createElement('a') : this.weightDataAnchor; weightDataAnchor_1.download = this.weightDataFileName; weightDataAnchor_1.href = weightsURL; return [4 /*yield*/, defer(function () { return weightDataAnchor_1.dispatchEvent(new MouseEvent('click')); })]; case 3: _a.sent(); _a.label = 4; case 4: return [2 /*return*/, { modelArtifactsInfo: getModelArtifactsInfoForJSON(modelArtifacts) }]; } }); }); }; BrowserDownloads.URL_SCHEME = 'downloads://'; return BrowserDownloads; }()); var BrowserFiles = /** @class */ (function () { function BrowserFiles(files) { if (files == null || files.length < 1) { throw new Error("When calling browserFiles, at least 1 file is required, " + ("but received " + files)); } this.files = files; } BrowserFiles.prototype.load = function () { return __awaiter(this, void 0, void 0, function () { var jsonFile, weightFiles; var _this = this; return __generator(this, function (_a) { jsonFile = this.files[0]; weightFiles = this.files.slice(1); return [2 /*return*/, new Promise(function (resolve, reject) { var jsonReader = new FileReader(); jsonReader.onload = function (event) { // tslint:disable-next-line:no-any var modelJSON = JSON.parse(event.target.result); var modelTopology = modelJSON.modelTopology; if (modelTopology == null) { reject(new Error("modelTopology field is missing from file " + jsonFile.name)); return; } if (weightFiles.length === 0) { resolve({ modelTopology: modelTopology }); } var weightsManifest = modelJSON.weightsManifest; if (weightsManifest == null) { reject(new Error("weightManifest field is missing from file " + jsonFile.name)); return; } var pathToFile; try { pathToFile = _this.checkManifestAndWeightFiles(weightsManifest, weightFiles); } catch (err) { reject(err); return; } var weightSpecs = []; var paths = []; var perFileBuffers = []; weightsManifest.forEach(function (weightsGroup) { weightsGroup.paths.forEach(function (path) { paths.push(path); perFileBuffers.push(null); }); weightSpecs.push.apply(weightSpecs, 
weightsGroup.weights); }); weightsManifest.forEach(function (weightsGroup) { weightsGroup.paths.forEach(function (path) { var weightFileReader = new FileReader(); weightFileReader.onload = function (event) { // tslint:disable-next-line:no-any var weightData = event.target.result; var index = paths.indexOf(path); perFileBuffers[index] = weightData; if (perFileBuffers.indexOf(null) === -1) { var result = { modelTopology: modelTopology, weightSpecs: weightSpecs, weightData: concatenateArrayBuffers(perFileBuffers), format: modelJSON.format, generatedBy: modelJSON.generatedBy, convertedBy: modelJSON.convertedBy }; if (modelJSON.signature != null) { result.signature = modelJSON.signature; } if (modelJSON.userDefinedMetadata != null) { result.userDefinedMetadata = modelJSON.userDefinedMetadata; } if (modelJSON.modelInitializer != null) { result.modelInitializer = modelJSON.modelInitializer; } resolve(result); } }; weightFileReader.onerror = function (error) { return reject("Failed to weights data from file of path '" + path + "'."); }; weightFileReader.readAsArrayBuffer(pathToFile[path]); }); }); }; jsonReader.onerror = function (error) { return reject("Failed to read model topology and weights manifest JSON " + ("from file '" + jsonFile.name + "'. BrowserFiles supports loading ") + "Keras-style tf.Model artifacts only."); }; jsonReader.readAsText(jsonFile); })]; }); }); }; /** * Check the compatibility between weights manifest and weight files. 
*/ BrowserFiles.prototype.checkManifestAndWeightFiles = function (manifest, files) { var basenames = []; var fileNames = files.map(function (file) { return basename(file.name); }); var pathToFile = {}; for (var _i = 0, manifest_1 = manifest; _i < manifest_1.length; _i++) { var group = manifest_1[_i]; group.paths.forEach(function (path) { var pathBasename = basename(path); if (basenames.indexOf(pathBasename) !== -1) { throw new Error("Duplicate file basename found in weights manifest: " + ("'" + pathBasename + "'")); } basenames.push(pathBasename); if (fileNames.indexOf(pathBasename) === -1) { throw new Error("Weight file with basename '" + pathBasename + "' is not provided."); } else { pathToFile[path] = files[fileNames.indexOf(pathBasename)]; } }); } if (basenames.length !== files.length) { throw new Error("Mismatch in the number of files in weights manifest " + ("(" + basenames.length + ") and the number of weight files provided ") + ("(" + files.length + ").")); } return pathToFile; }; return BrowserFiles; }()); var browserDownloadsRouter = function (url) { if (!env().getBool('IS_BROWSER')) { return null; } else { if (!Array.isArray(url) && url.startsWith(BrowserDownloads.URL_SCHEME)) { return browserDownloads(url.slice(BrowserDownloads.URL_SCHEME.length)); } else { return null; } } }; IORouterRegistry.registerSaveRouter(browserDownloadsRouter); /** * Creates an IOHandler that triggers file downloads from the browser. * * The returned `IOHandler` instance can be used as model exporting methods such * as `tf.Model.save` and supports only saving. * * ```js * const model = tf.sequential(); * model.add(tf.layers.dense( * {units: 1, inputShape: [10], activation: 'sigmoid'})); * const saveResult = await model.save('downloads://mymodel'); * // This will trigger downloading of two files: * // 'mymodel.json' and 'mymodel.weights.bin'. * console.log(saveResult); * ``` * * @param fileNamePrefix Prefix name of the files to be downloaded. 
For use with * `tf.Model`, `fileNamePrefix` should follow either of the following two * formats: * 1. `null` or `undefined`, in which case the default file * names will be used: * - 'model.json' for the JSON file containing the model topology and * weights manifest. * - 'model.weights.bin' for the binary file containing the binary weight * values. * 2. A single string or an Array of a single string, as the file name prefix. * For example, if `'foo'` is provided, the downloaded JSON * file and binary weights file will be named 'foo.json' and * 'foo.weights.bin', respectively. * @param config Additional configuration for triggering downloads. * @returns An instance of `BrowserDownloads` `IOHandler`. * * @doc { * heading: 'Models', * subheading: 'Loading', * namespace: 'io', * ignoreCI: true * } */ function browserDownloads(fileNamePrefix) { if (fileNamePrefix === void 0) { fileNamePrefix = 'model'; } return new BrowserDownloads(fileNamePrefix); } /** * Creates an IOHandler that loads model artifacts from user-selected files. * * This method can be used for loading from files such as user-selected files * in the browser. * When used in conjunction with `tf.loadLayersModel`, an instance of * `tf.LayersModel` (Keras-style) can be constructed from the loaded artifacts. * * ```js * // Note: This code snippet won't run properly without the actual file input * // elements in the HTML DOM. * * // Suppose there are two HTML file input (``) * // elements. * const uploadJSONInput = document.getElementById('upload-json'); * const uploadWeightsInput = document.getElementById('upload-weights'); * const model = await tf.loadLayersModel(tf.io.browserFiles( * [uploadJSONInput.files[0], uploadWeightsInput.files[0]])); * ``` * * @param files `File`s to load from. 
Currently, this function supports only * loading from files that contain Keras-style models (i.e., `tf.Model`s), for * which an `Array` of `File`s is expected (in that order): * - A JSON file containing the model topology and weight manifest. * - Optionally, One or more binary files containing the binary weights. * These files must have names that match the paths in the `weightsManifest` * contained by the aforementioned JSON file, or errors will be thrown * during loading. These weights files have the same format as the ones * generated by `tensorflowjs_converter` that comes with the `tensorflowjs` * Python PIP package. If no weights files are provided, only the model * topology will be loaded from the JSON file above. * @returns An instance of `Files` `IOHandler`. * * @doc { * heading: 'Models', * subheading: 'Loading', * namespace: 'io', * ignoreCI: true * } */ function browserFiles(files) { return new BrowserFiles(files); } /** * @license * Copyright 2019 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /** * Monitor Promise.all progress, fire onProgress callback function. * * @param promises Promise list going to be monitored * @param onProgress Callback function. Fired when a promise resolved. * @param startFraction Optional fraction start. Default to 0. * @param endFraction Optional fraction end. Default to 1. 
*/ function monitorPromisesProgress(promises, onProgress, startFraction, endFraction) { checkPromises(promises); startFraction = startFraction == null ? 0 : startFraction; endFraction = endFraction == null ? 1 : endFraction; checkFraction(startFraction, endFraction); var resolvedPromise = 0; var registerMonitor = function (promise) { promise.then(function (value) { var fraction = startFraction + ++resolvedPromise / promises.length * (endFraction - startFraction); // pass fraction as parameter to callback function. onProgress(fraction); return value; }); return promise; }; function checkPromises(promises) { assert(promises != null && Array.isArray(promises) && promises.length > 0, function () { return 'promises must be a none empty array'; }); } function checkFraction(startFraction, endFraction) { assert(startFraction >= 0 && startFraction <= 1, function () { return "Progress fraction must be in range [0, 1], but " + ("got startFraction " + startFraction); }); assert(endFraction >= 0 && endFraction <= 1, function () { return "Progress fraction must be in range [0, 1], but " + ("got endFraction " + endFraction); }); assert(endFraction >= startFraction, function () { return "startFraction must be no more than endFraction, but " + ("got startFraction " + startFraction + " and endFraction ") + ("" + endFraction); }); } return Promise.all(promises.map(registerMonitor)); } /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ /** * Reads binary weights data from a number of URLs. * * @param fetchURLs URLs to send the HTTP requests at, using `fetch` calls. * @param requestOptions RequestInit (options) for the HTTP requests. * @param fetchFunc Optional overriding value for the `window.fetch` function. * @param onProgress Optional, progress callback function, fired periodically * before the load is completed. * @returns A `Promise` of an Array of `ArrayBuffer`. The Array has the same * length as `fetchURLs`. */ function loadWeightsAsArrayBuffer(fetchURLs, loadOptions) { return __awaiter(this, void 0, void 0, function () { var fetchFunc, requests, fetchStartFraction, fetchEndFraction, responses, _a, bufferPromises, bufferStartFraction, bufferEndFraction, buffers, _b; return __generator(this, function (_c) { switch (_c.label) { case 0: if (loadOptions == null) { loadOptions = {}; } fetchFunc = loadOptions.fetchFunc == null ? env().platform.fetch : loadOptions.fetchFunc; requests = fetchURLs.map(function (fetchURL) { return fetchFunc(fetchURL, loadOptions.requestInit, { isBinary: true }); }); fetchStartFraction = 0; fetchEndFraction = 0.5; if (!(loadOptions.onProgress == null)) return [3 /*break*/, 2]; return [4 /*yield*/, Promise.all(requests)]; case 1: _a = _c.sent(); return [3 /*break*/, 4]; case 2: return [4 /*yield*/, monitorPromisesProgress(requests, loadOptions.onProgress, fetchStartFraction, fetchEndFraction)]; case 3: _a = _c.sent(); _c.label = 4; case 4: responses = _a; bufferPromises = responses.map(function (response) { return response.arrayBuffer(); }); bufferStartFraction = 0.5; bufferEndFraction = 1; if (!(loadOptions.onProgress == null)) return [3 /*break*/, 6]; return [4 /*yield*/, Promise.all(bufferPromises)]; case 5: _b = _c.sent(); return [3 /*break*/, 8]; case 6: return [4 /*yield*/, monitorPromisesProgress(bufferPromises, loadOptions.onProgress, bufferStartFraction, bufferEndFraction)]; 
case 7: _b = _c.sent(); _c.label = 8; case 8: buffers = _b; return [2 /*return*/, buffers]; } }); }); } /** * Reads a weights manifest JSON configuration, fetches the weights and * returns them as `Tensor`s. * * @param manifest The weights manifest JSON. * @param filePathPrefix The path prefix for filenames given in the manifest. * Defaults to the empty string. * @param weightNames The names of the weights to be fetched. */ function loadWeights(manifest, filePathPrefix, weightNames, requestInit) { if (filePathPrefix === void 0) { filePathPrefix = ''; } return __awaiter(this, void 0, void 0, function () { var fetchWeights, loadWeights; return __generator(this, function (_a) { fetchWeights = function (fetchUrls) { return loadWeightsAsArrayBuffer(fetchUrls, { requestInit: requestInit }); }; loadWeights = weightsLoaderFactory(fetchWeights); return [2 /*return*/, loadWeights(manifest, filePathPrefix, weightNames)]; }); }); } /** * Creates a function, which reads a weights manifest JSON configuration, * fetches the weight files using the specified function and returns them as * `Tensor`s. * * ```js * // example for creating a nodejs weight loader, which reads the weight files * // from disk using fs.readFileSync * * import * as fs from 'fs' * * const fetchWeightsFromDisk = (filePaths: string[]) => * filePaths.map(filePath => fs.readFileSync(filePath).buffer) * * const loadWeights = tf.io.weightsLoaderFactory(fetchWeightsFromDisk) * * const manifest = JSON.parse( * fs.readFileSync('./my_model-weights_manifest').toString() * ) * const weightMap = await loadWeights(manifest, './') * ``` * @param fetchWeightsFunction The function used for fetching the weight files. * @returns Weight loading function. 
*/ function weightsLoaderFactory(fetchWeightsFunction) { var _this = this; return function (manifest, filePathPrefix, weightNames) { if (filePathPrefix === void 0) { filePathPrefix = ''; } return __awaiter(_this, void 0, void 0, function () { var groupIndicesToFetchMap, groupWeightsToFetch, weightsFound, allManifestWeightNames, weightsNotFound, groupIndicesToFetch, fetchUrls, buffers, weightsTensorMap, bufferIndexOffset; return __generator(this, function (_a) { switch (_a.label) { case 0: groupIndicesToFetchMap = manifest.map(function () { return false; }); groupWeightsToFetch = {}; weightsFound = weightNames != null ? weightNames.map(function () { return false; }) : []; allManifestWeightNames = []; manifest.forEach(function (manifestGroupConfig, groupIndex) { var groupOffset = 0; manifestGroupConfig.weights.forEach(function (weightsEntry) { var rawDtype = ('quantization' in weightsEntry) ? weightsEntry.quantization.dtype : weightsEntry.dtype; var weightsBytes = DTYPE_VALUE_SIZE_MAP[rawDtype] * sizeFromShape(weightsEntry.shape); var enqueueWeightsForFetchingFn = function () { groupIndicesToFetchMap[groupIndex] = true; if (groupWeightsToFetch[groupIndex] == null) { groupWeightsToFetch[groupIndex] = []; } groupWeightsToFetch[groupIndex].push({ manifestEntry: weightsEntry, groupOffset: groupOffset, sizeBytes: weightsBytes }); }; if (weightNames != null) { weightNames.forEach(function (weightName, weightIndex) { if (weightName === weightsEntry.name) { enqueueWeightsForFetchingFn(); weightsFound[weightIndex] = true; } }); } else { enqueueWeightsForFetchingFn(); } allManifestWeightNames.push(weightsEntry.name); groupOffset += weightsBytes; }); }); if (!weightsFound.every(function (found) { return found; })) { weightsNotFound = weightNames.filter(function (_, i) { return !weightsFound[i]; }); throw new Error("Could not find weights in manifest with names: " + (weightsNotFound.join(', ') + ". 
\n") + "Manifest JSON has weights with names: " + (allManifestWeightNames.join(', ') + ".")); } groupIndicesToFetch = groupIndicesToFetchMap.reduce(function (accumulator, shouldFetch, i) { if (shouldFetch) { accumulator.push(i); } return accumulator; }, []); fetchUrls = []; groupIndicesToFetch.forEach(function (i) { manifest[i].paths.forEach(function (filepath) { var fetchUrl = filePathPrefix + (!filePathPrefix.endsWith('/') ? '/' : '') + filepath; fetchUrls.push(fetchUrl); }); }); return [4 /*yield*/, fetchWeightsFunction(fetchUrls)]; case 1: buffers = _a.sent(); weightsTensorMap = {}; bufferIndexOffset = 0; groupIndicesToFetch.forEach(function (i) { var numBuffers = manifest[i].paths.length; var groupBytes = 0; for (var i_1 = 0; i_1 < numBuffers; i_1++) { groupBytes += buffers[bufferIndexOffset + i_1].byteLength; } // Create a buffer for the whole group. var groupBuffer = new ArrayBuffer(groupBytes); var groupByteBuffer = new Uint8Array(groupBuffer); var groupBufferOffset = 0; for (var i_2 = 0; i_2 < numBuffers; i_2++) { var buffer = new Uint8Array(buffers[bufferIndexOffset + i_2]); groupByteBuffer.set(buffer, groupBufferOffset); groupBufferOffset += buffer.byteLength; } var weightsEntries = groupWeightsToFetch[i]; weightsEntries.forEach(function (weightsEntry) { var byteBuffer = groupBuffer.slice(weightsEntry.groupOffset, weightsEntry.groupOffset + weightsEntry.sizeBytes); var nameToTensorMap = decodeWeights(byteBuffer, [weightsEntry.manifestEntry]); for (var name_1 in nameToTensorMap) { weightsTensorMap[name_1] = nameToTensorMap[name_1]; } }); bufferIndexOffset += numBuffers; }); return [2 /*return*/, weightsTensorMap]; } }); }); }; } /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var OCTET_STREAM_MIME_TYPE = 'application/octet-stream'; var JSON_TYPE = 'application/json'; var HTTPRequest = /** @class */ (function () { function HTTPRequest(path, loadOptions) { this.DEFAULT_METHOD = 'POST'; if (loadOptions == null) { loadOptions = {}; } this.weightPathPrefix = loadOptions.weightPathPrefix; this.onProgress = loadOptions.onProgress; this.weightUrlConverter = loadOptions.weightUrlConverter; if (loadOptions.fetchFunc != null) { assert(typeof loadOptions.fetchFunc === 'function', function () { return 'Must pass a function that matches the signature of ' + '`fetch` (see ' + 'https://developer.mozilla.org/en-US/docs/Web/API/Fetch_API)'; }); this.fetch = loadOptions.fetchFunc; } else { this.fetch = env().platform.fetch; } assert(path != null && path.length > 0, function () { return 'URL path for http must not be null, undefined or ' + 'empty.'; }); if (Array.isArray(path)) { assert(path.length === 2, function () { return 'URL paths for http must have a length of 2, ' + ("(actual length is " + path.length + ")."); }); } this.path = path; if (loadOptions.requestInit != null && loadOptions.requestInit.body != null) { throw new Error('requestInit is expected to have no pre-existing body, but has one.'); } this.requestInit = loadOptions.requestInit || {}; } HTTPRequest.prototype.save = function (modelArtifacts) { return __awaiter(this, void 0, void 0, function () { var init, weightsManifest, modelTopologyAndWeightManifest, response; return __generator(this, function (_a) { 
switch (_a.label) { case 0: if (modelArtifacts.modelTopology instanceof ArrayBuffer) { throw new Error('BrowserHTTPRequest.save() does not support saving model topology ' + 'in binary formats yet.'); } init = Object.assign({ method: this.DEFAULT_METHOD }, this.requestInit); init.body = new FormData(); weightsManifest = [{ paths: ['./model.weights.bin'], weights: modelArtifacts.weightSpecs, }]; modelTopologyAndWeightManifest = { modelTopology: modelArtifacts.modelTopology, format: modelArtifacts.format, generatedBy: modelArtifacts.generatedBy, convertedBy: modelArtifacts.convertedBy, weightsManifest: weightsManifest }; if (modelArtifacts.signature != null) { modelTopologyAndWeightManifest.signature = modelArtifacts.signature; } if (modelArtifacts.userDefinedMetadata != null) { modelTopologyAndWeightManifest.userDefinedMetadata = modelArtifacts.userDefinedMetadata; } if (modelArtifacts.modelInitializer != null) { modelTopologyAndWeightManifest.modelInitializer = modelArtifacts.modelInitializer; } init.body.append('model.json', new Blob([JSON.stringify(modelTopologyAndWeightManifest)], { type: JSON_TYPE }), 'model.json'); if (modelArtifacts.weightData != null) { init.body.append('model.weights.bin', new Blob([modelArtifacts.weightData], { type: OCTET_STREAM_MIME_TYPE }), 'model.weights.bin'); } return [4 /*yield*/, this.fetch(this.path, init)]; case 1: response = _a.sent(); if (response.ok) { return [2 /*return*/, { modelArtifactsInfo: getModelArtifactsInfoForJSON(modelArtifacts), responses: [response], }]; } else { throw new Error("BrowserHTTPRequest.save() failed due to HTTP response status " + (response.status + ".")); } } }); }); }; /** * Load model artifacts via HTTP request(s). * * See the documentation to `tf.io.http` for details on the saved * artifacts. * * @returns The loaded model artifacts (if loading succeeds). 
*/ HTTPRequest.prototype.load = function () { return __awaiter(this, void 0, void 0, function () { var modelConfigRequest, modelConfig, e_1, message, modelTopology, weightsManifest, generatedBy, convertedBy, format, signature, userDefinedMetadata, weightSpecs, weightData, results, artifacts, initializer; return __generator(this, function (_a) { switch (_a.label) { case 0: return [4 /*yield*/, this.fetch(this.path, this.requestInit)]; case 1: modelConfigRequest = _a.sent(); if (!modelConfigRequest.ok) { throw new Error("Request to " + this.path + " failed with status code " + (modelConfigRequest.status + ". Please verify this URL points to ") + "the model JSON of the model to load."); } _a.label = 2; case 2: _a.trys.push([2, 4, , 5]); return [4 /*yield*/, modelConfigRequest.json()]; case 3: modelConfig = _a.sent(); return [3 /*break*/, 5]; case 4: e_1 = _a.sent(); message = "Failed to parse model JSON of response from " + this.path + "."; // TODO(nsthorat): Remove this after some time when we're comfortable that // .pb files are mostly gone. if (this.path.endsWith('.pb')) { message += ' Your path contains a .pb file extension. ' + 'Support for .pb models have been removed in TensorFlow.js 1.0 ' + 'in favor of .json models. You can re-convert your Python ' + 'TensorFlow model using the TensorFlow.js 1.0 conversion scripts ' + 'or you can convert your.pb models with the \'pb2json\'' + 'NPM script in the tensorflow/tfjs-converter repository.'; } else { message += ' Please make sure the server is serving valid ' + 'JSON for this request.'; } throw new Error(message); case 5: modelTopology = modelConfig.modelTopology; weightsManifest = modelConfig.weightsManifest; generatedBy = modelConfig.generatedBy; convertedBy = modelConfig.convertedBy; format = modelConfig.format; signature = modelConfig.signature; userDefinedMetadata = modelConfig.userDefinedMetadata; // We do not allow both modelTopology and weightsManifest to be missing. 
if (modelTopology == null && weightsManifest == null) { throw new Error("The JSON from HTTP path " + this.path + " contains neither model " + "topology or manifest for weights."); } if (!(weightsManifest != null)) return [3 /*break*/, 7]; return [4 /*yield*/, this.loadWeights(weightsManifest)]; case 6: results = _a.sent(); weightSpecs = results[0], weightData = results[1]; _a.label = 7; case 7: artifacts = { modelTopology: modelTopology, weightSpecs: weightSpecs, weightData: weightData, generatedBy: generatedBy, convertedBy: convertedBy, format: format }; if (signature != null) { artifacts.signature = signature; } if (userDefinedMetadata != null) { artifacts.userDefinedMetadata = userDefinedMetadata; } initializer = modelConfig.modelInitializer; if (initializer) { artifacts.modelInitializer = initializer; } return [2 /*return*/, artifacts]; } }); }); }; HTTPRequest.prototype.loadWeights = function (weightsManifest) { return __awaiter(this, void 0, void 0, function () { var weightPath, _a, prefix, suffix, pathPrefix, weightSpecs, _i, weightsManifest_1, entry, fetchURLs, urlPromises, _b, weightsManifest_2, weightsGroup, _c, _d, path, _e, _f, _g, buffers; return __generator(this, function (_h) { switch (_h.label) { case 0: weightPath = Array.isArray(this.path) ? 
this.path[1] : this.path; _a = parseUrl(weightPath), prefix = _a[0], suffix = _a[1]; pathPrefix = this.weightPathPrefix || prefix; weightSpecs = []; for (_i = 0, weightsManifest_1 = weightsManifest; _i < weightsManifest_1.length; _i++) { entry = weightsManifest_1[_i]; weightSpecs.push.apply(weightSpecs, entry.weights); } fetchURLs = []; urlPromises = []; for (_b = 0, weightsManifest_2 = weightsManifest; _b < weightsManifest_2.length; _b++) { weightsGroup = weightsManifest_2[_b]; for (_c = 0, _d = weightsGroup.paths; _c < _d.length; _c++) { path = _d[_c]; if (this.weightUrlConverter != null) { urlPromises.push(this.weightUrlConverter(path)); } else { fetchURLs.push(pathPrefix + path + suffix); } } } if (!this.weightUrlConverter) return [3 /*break*/, 2]; _f = (_e = fetchURLs.push).apply; _g = [fetchURLs]; return [4 /*yield*/, Promise.all(urlPromises)]; case 1: _f.apply(_e, _g.concat([_h.sent()])); _h.label = 2; case 2: return [4 /*yield*/, loadWeightsAsArrayBuffer(fetchURLs, { requestInit: this.requestInit, fetchFunc: this.fetch, onProgress: this.onProgress })]; case 3: buffers = _h.sent(); return [2 /*return*/, [weightSpecs, concatenateArrayBuffers(buffers)]]; } }); }); }; HTTPRequest.URL_SCHEME_REGEX = /^https?:\/\//; return HTTPRequest; }()); /** * Extract the prefix and suffix of the url, where the prefix is the path before * the last file, and suffix is the search params after the last file. * ``` * const url = 'http://tfhub.dev/model/1/tensorflowjs_model.pb?tfjs-format=file' * [prefix, suffix] = parseUrl(url) * // prefix = 'http://tfhub.dev/model/1/' * // suffix = '?tfjs-format=file' * ``` * @param url the model url to be parsed. */ function parseUrl(url) { var lastSlash = url.lastIndexOf('/'); var lastSearchParam = url.lastIndexOf('?'); var prefix = url.substring(0, lastSlash); var suffix = lastSearchParam > lastSlash ? 
url.substring(lastSearchParam) : ''; return [prefix + '/', suffix]; } function isHTTPScheme(url) { return url.match(HTTPRequest.URL_SCHEME_REGEX) != null; } var httpRouter = function (url, loadOptions) { if (typeof fetch === 'undefined' && (loadOptions == null || loadOptions.fetchFunc == null)) { // `http` uses `fetch` or `node-fetch`, if one wants to use it in // an environment that is not the browser or node they have to setup a // global fetch polyfill. return null; } else { var isHTTP = true; if (Array.isArray(url)) { isHTTP = url.every(function (urlItem) { return isHTTPScheme(urlItem); }); } else { isHTTP = isHTTPScheme(url); } if (isHTTP) { return http(url, loadOptions); } } return null; }; IORouterRegistry.registerSaveRouter(httpRouter); IORouterRegistry.registerLoadRouter(httpRouter); /** * Creates an IOHandler subtype that sends model artifacts to HTTP server. * * An HTTP request of the `multipart/form-data` mime type will be sent to the * `path` URL. The form data includes artifacts that represent the topology * and/or weights of the model. In the case of Keras-style `tf.Model`, two * blobs (files) exist in form-data: * - A JSON file consisting of `modelTopology` and `weightsManifest`. * - A binary weights file consisting of the concatenated weight values. * These files are in the same format as the one generated by * [tfjs_converter](https://js.tensorflow.org/tutorials/import-keras.html). 
* * The following code snippet exemplifies the client-side code that uses this * function: * * ```js * const model = tf.sequential(); * model.add( * tf.layers.dense({units: 1, inputShape: [100], activation: 'sigmoid'})); * * const saveResult = await model.save(tf.io.http( * 'http://model-server:5000/upload', {requestInit: {method: 'PUT'}})); * console.log(saveResult); * ``` * * If the default `POST` method is to be used, without any custom parameters * such as headers, you can simply pass an HTTP or HTTPS URL to `model.save`: * * ```js * const saveResult = await model.save('http://model-server:5000/upload'); * ``` * * The following GitHub Gist * https://gist.github.com/dsmilkov/1b6046fd6132d7408d5257b0976f7864 * implements a server based on [flask](https://github.com/pallets/flask) that * can receive the request. Upon receiving the model artifacts via the requst, * this particular server reconsistutes instances of [Keras * Models](https://keras.io/models/model/) in memory. * * * @param path A URL path to the model. * Can be an absolute HTTP path (e.g., * 'http://localhost:8000/model-upload)') or a relative path (e.g., * './model-upload'). * @param requestInit Request configurations to be used when sending * HTTP request to server using `fetch`. It can contain fields such as * `method`, `credentials`, `headers`, `mode`, etc. See * https://developer.mozilla.org/en-US/docs/Web/API/Request/Request * for more information. `requestInit` must not have a body, because the * body will be set by TensorFlow.js. File blobs representing the model * topology (filename: 'model.json') and the weights of the model (filename: * 'model.weights.bin') will be appended to the body. If `requestInit` has a * `body`, an Error will be thrown. * @param loadOptions Optional configuration for the loading. It includes the * following fields: * - weightPathPrefix Optional, this specifies the path prefix for weight * files, by default this is calculated from the path param. 
* - fetchFunc Optional, custom `fetch` function. E.g., in Node.js, * the `fetch` from node-fetch can be used here. * - onProgress Optional, progress callback function, fired periodically * before the load is completed. * @returns An instance of `IOHandler`. * * @doc { * heading: 'Models', * subheading: 'Loading', * namespace: 'io', * ignoreCI: true * } */ function http(path, loadOptions) { return new HTTPRequest(path, loadOptions); } /** * Deprecated. Use `tf.io.http`. * @param path * @param loadOptions */ function browserHTTPRequest(path, loadOptions) { return http(path, loadOptions); } /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* ============================================================================= */ var PassthroughLoader = /** @class */ (function () { function PassthroughLoader(modelArtifacts) { this.modelArtifacts = modelArtifacts; } PassthroughLoader.prototype.load = function () { return __awaiter(this, void 0, void 0, function () { return __generator(this, function (_a) { return [2 /*return*/, this.modelArtifacts]; }); }); }; return PassthroughLoader; }()); var PassthroughSaver = /** @class */ (function () { function PassthroughSaver(saveHandler) { this.saveHandler = saveHandler; } PassthroughSaver.prototype.save = function (modelArtifacts) { return __awaiter(this, void 0, void 0, function () { return __generator(this, function (_a) { return [2 /*return*/, this.saveHandler(modelArtifacts)]; }); }); }; return PassthroughSaver; }()); /** * Creates an IOHandler that loads model artifacts from memory. * * When used in conjunction with `tf.loadLayersModel`, an instance of * `tf.LayersModel` (Keras-style) can be constructed from the loaded artifacts. * * ```js * const model = await tf.loadLayersModel(tf.io.fromMemory( * modelTopology, weightSpecs, weightData)); * ``` * * @param modelArtifacts a object containing model topology (i.e., parsed from * the JSON format). * @param weightSpecs An array of `WeightsManifestEntry` objects describing the * names, shapes, types, and quantization of the weight data. * @param weightData A single `ArrayBuffer` containing the weight data, * concatenated in the order described by the weightSpecs. * @param trainingConfig Model training configuration. Optional. * * @returns A passthrough `IOHandler` that simply loads the provided data. 
*/ function fromMemory(modelArtifacts, weightSpecs, weightData, trainingConfig) { if (arguments.length === 1) { var isModelArtifacts = modelArtifacts.modelTopology != null || modelArtifacts.weightSpecs != null; if (isModelArtifacts) { return new PassthroughLoader(modelArtifacts); } else { // Legacy support: with only modelTopology. // TODO(cais): Remove this deprecated API. console.warn('Please call tf.io.fromMemory() with only one argument. ' + 'The argument should be of type ModelArtifacts. ' + 'The multi-argument signature of tf.io.fromMemory() has been ' + 'deprecated and will be removed in a future release.'); return new PassthroughLoader({ modelTopology: modelArtifacts }); } } else { // Legacy support. // TODO(cais): Remove this deprecated API. console.warn('Please call tf.io.fromMemory() with only one argument. ' + 'The argument should be of type ModelArtifacts. ' + 'The multi-argument signature of tf.io.fromMemory() has been ' + 'deprecated and will be removed in a future release.'); return new PassthroughLoader({ modelTopology: modelArtifacts, weightSpecs: weightSpecs, weightData: weightData, trainingConfig: trainingConfig }); } } /** * Creates an IOHandler that passes saved model artifacts to a callback. * * ```js * function handleSave(artifacts) { * // ... do something with the artifacts ... * return {modelArtifactsInfo: {...}, ...}; * } * * const saveResult = model.save(tf.io.withSaveHandler(handleSave)); * ``` * * @param saveHandler A function that accepts a `ModelArtifacts` and returns a * `SaveResult`. */ function withSaveHandler(saveHandler) { return new PassthroughSaver(saveHandler); } /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var io = { __proto__: null, browserFiles: browserFiles, browserHTTPRequest: browserHTTPRequest, concatenateArrayBuffers: concatenateArrayBuffers, decodeWeights: decodeWeights, encodeWeights: encodeWeights, fromMemory: fromMemory, getLoadHandlers: getLoadHandlers, getModelArtifactsInfoForJSON: getModelArtifactsInfoForJSON, getSaveHandlers: getSaveHandlers, http: http, isHTTPScheme: isHTTPScheme, loadWeights: loadWeights, registerLoadRouter: registerLoadRouter, registerSaveRouter: registerSaveRouter, weightsLoaderFactory: weightsLoaderFactory, withSaveHandler: withSaveHandler, copyModel: copyModel, listModels: listModels, moveModel: moveModel, removeModel: removeModel }; /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /** * Computes the dot product of two matrices, A * B. These must be matrices. 
* * ```js * const a = tf.tensor2d([1, 2], [1, 2]); * const b = tf.tensor2d([1, 2, 3, 4], [2, 2]); * * a.matMul(b).print(); // or tf.matMul(a, b) * ``` * @param a First matrix in dot product operation. * @param b Second matrix in dot product operation. * @param transposeA If true, `a` is transposed before multiplication. * @param transposeB If true, `b` is transposed before multiplication. * * @doc {heading: 'Operations', subheading: 'Matrices'} */ function matMul_(a, b, transposeA, transposeB) { var _a; if (transposeA === void 0) { transposeA = false; } if (transposeB === void 0) { transposeB = false; } var $a = convertToTensor(a, 'a', 'matMul'); var $b = convertToTensor(b, 'b', 'matMul'); _a = makeTypesMatch($a, $b), $a = _a[0], $b = _a[1]; var inputs = { a: $a, b: $b }; var attrs = { transposeA: transposeA, transposeB: transposeB }; return ENGINE.runKernel(BatchMatMul, inputs, attrs); } var matMul = op({ matMul_: matMul_ }); /** * @license * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /** * Creates a one-hot `tf.Tensor`. The locations represented by `indices` take * value `onValue` (defaults to 1), while all other locations take value * `offValue` (defaults to 0). If `indices` is rank `R`, the output has rank * `R+1` with the last axis of size `depth`. 
* * ```js * tf.oneHot(tf.tensor1d([0, 1], 'int32'), 3).print(); * ``` * * @param indices `tf.Tensor` of indices with dtype `int32`. * @param depth The depth of the one hot dimension. * @param onValue A number used to fill in the output when the index matches * the location. * @param offValue A number used to fill in the output when the index does * not match the location. * * @doc {heading: 'Tensors', subheading: 'Creation'} */ function oneHot_(indices, depth, onValue, offValue) { if (onValue === void 0) { onValue = 1; } if (offValue === void 0) { offValue = 0; } if (depth < 2) { throw new Error("Error in oneHot: depth must be >=2, but it is " + depth); } var $indices = convertToTensor(indices, 'indices', 'oneHot', 'int32'); var inputs = { indices: $indices }; var attrs = { depth: depth, onValue: onValue, offValue: offValue }; return ENGINE.runKernel(OneHot, inputs, attrs); } var oneHot = op({ oneHot_: oneHot_ }); /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /** * Transposes the `tf.Tensor`. Permutes the dimensions according to `perm`. * * The returned `tf.Tensor`'s dimension `i` will correspond to the input * dimension `perm[i]`. If `perm` is not given, it is set to `[n-1...0]`, * where `n` is the rank of the input `tf.Tensor`. Hence by default, this * operation performs a regular matrix transpose on 2-D input `tf.Tensor`s. 
* * ```js * const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]); * * a.transpose().print(); // or tf.transpose(a) * ``` * * @param x The tensor to transpose. * @param perm The permutation of the dimensions of a. * * @doc {heading: 'Operations', subheading: 'Matrices'} */ function transpose_(x, perm) { var $x = convertToTensor(x, 'x', 'transpose'); if (perm == null) { perm = $x.shape.map(function (s, i) { return i; }).reverse(); } assert($x.rank === perm.length, function () { return "Error in transpose: rank of input " + $x.rank + " " + ("must match length of perm " + perm + "."); }); perm.forEach(function (axis) { assert(axis >= 0 && axis < $x.rank, function () { return "All entries in 'perm' must be between 0 and " + ($x.rank - 1) + (" but got " + perm); }); }); if ($x.rank <= 1) { return $x.clone(); } var inputs = { x: $x }; var attrs = { perm: perm }; return ENGINE.runKernel(Transpose, inputs, attrs); } var transpose = op({ transpose_: transpose_ }); /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /** * Computes the confusion matrix from true labels and predicted labels. 
* * ```js * const labels = tf.tensor1d([0, 1, 2, 1, 0], 'int32'); * const predictions = tf.tensor1d([0, 2, 2, 1, 0], 'int32'); * const numClasses = 3; * const out = tf.math.confusionMatrix(labels, predictions, numClasses); * out.print(); * // Expected output matrix: * // [[2, 0, 0], * // [0, 1, 1], * // [0, 0, 1]] * ``` * * @param labels The target labels, assumed to be 0-based integers * for the classes. The shape is `[numExamples]`, where * `numExamples` is the number of examples included. * @param predictions The predicted classes, assumed to be * 0-based integers for the classes. Must have the same shape as `labels`. * @param numClasses Number of all classes, as an integer. * Its value must be larger than the largest element in `labels` and * `predictions`. * @returns The confusion matrix as a int32-type 2D tensor. The value at * row `r` and column `c` is the number of times examples of actual class * `r` were predicted as class `c`. * * @doc {heading: 'Operations', subheading: 'Evaluation'} */ function confusionMatrix_(labels, predictions, numClasses) { var $labels = convertToTensor(labels, 'labels', 'confusionMatrix'); var $predictions = convertToTensor(predictions, 'predictions', 'confusionMatrix'); assert(numClasses == null || numClasses > 0 && Number.isInteger(numClasses), function () { return "If provided, numClasses must be a positive integer, " + ("but got " + numClasses); }); assert($labels.rank === 1, function () { return "Expected the rank of labels to be 1, but got " + $labels.rank; }); assert($predictions.rank === 1, function () { return "Expected the rank of predictions to be 1, " + ("but got " + $predictions.rank); }); assert($labels.shape[0] === $predictions.shape[0], function () { return "Mismatch in the number of examples: " + ($labels.shape[0] + " vs. " + $predictions.shape[0] + ". 
") + "Labels and predictions should have the same number of elements."; }); assert(numClasses > 0 && Number.isInteger(numClasses), function () { return "numClasses is required to be a positive integer, but got " + ("" + numClasses); }); // TODO(cais): In the future, if oneHot supports tensors inputs for // `numClasses`, `confusionMatrix` can make `numClasses` optional. var oneHotLabels = oneHot(cast($labels, 'int32'), numClasses); var oneHotPredictions = oneHot(cast($predictions, 'int32'), numClasses); var oneHotLabelsT = transpose(oneHotLabels); var product = matMul(oneHotLabelsT, oneHotPredictions); return cast(product, 'int32'); } var confusionMatrix = op({ confusionMatrix_: confusionMatrix_ }); /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var math = { __proto__: null, confusionMatrix: confusionMatrix }; /** * @license * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
* See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ /** * Creates rank-3 `tf.Tensor` with the provided values, shape and dtype. * * The same functionality can be achieved with `tf.tensor`, but in general * we recommend using `tf.tensor3d` as it makes the code more readable. * * ```js * // Pass a nested array. * tf.tensor3d([[[1], [2]], [[3], [4]]]).print(); * ``` * ```js * // Pass a flat array and specify a shape. * tf.tensor3d([1, 2, 3, 4], [2, 2, 1]).print(); * ``` * * @param values The values of the tensor. Can be nested array of numbers, * or a flat array, or a `TypedArray`. * @param shape The shape of the tensor. If not provided, it is inferred from * `values`. * @param dtype The data type. * * @doc {heading: 'Tensors', subheading: 'Creation'} */ function tensor3d(values, shape, dtype) { assertNonNull(values); if (shape != null && shape.length !== 3) { throw new Error('tensor3d() requires shape to have three numbers'); } var inferredShape = inferShape(values, dtype); if (inferredShape.length !== 3 && inferredShape.length !== 1) { throw new Error('tensor3d() requires values to be number[][][] or flat/TypedArray'); } if (inferredShape.length === 1 && shape == null) { throw new Error('tensor3d() requires shape to be provided when `values` ' + 'are a flat array'); } return makeTensor(values, shape, inferredShape, dtype); } /** * @license * Copyright 2019 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
* See the License for the specific language governing permissions and * limitations under the License. * ============================================================================= */ var fromPixels2DContext; /** * Creates a `tf.Tensor` from an image. * * ```js * const image = new ImageData(1, 1); * image.data[0] = 100; * image.data[1] = 150; * image.data[2] = 200; * image.data[3] = 255; * * tf.browser.fromPixels(image).print(); * ``` * * @param pixels The input image to construct the tensor from. The * supported image types are all 4-channel. You can also pass in an image * object with following attributes: * `{data: Uint8Array; width: number; height: number}` * @param numChannels The number of channels of the output tensor. A * numChannels value less than 4 allows you to ignore channels. Defaults to * 3 (ignores alpha channel of input image). * * @returns A Tensor3D with the shape `[height, width, numChannels]`. * * @doc {heading: 'Browser', namespace: 'browser', ignoreCI: true} */ function fromPixels_(pixels, numChannels) { if (numChannels === void 0) { numChannels = 3; } // Sanity checks. 
if (numChannels > 4) { throw new Error('Cannot construct Tensor with more than 4 channels from pixels.'); } if (pixels == null) { throw new Error('pixels passed to tf.browser.fromPixels() can not be null'); } var isPixelData = false; var isImageData = false; var isVideo = false; var isImage = false; var isCanvasLike = false; var isImageBitmap = false; if (pixels.data instanceof Uint8Array) { isPixelData = true; } else if (typeof (ImageData) !== 'undefined' && pixels instanceof ImageData) { isImageData = true; } else if (typeof (HTMLVideoElement) !== 'undefined' && pixels instanceof HTMLVideoElement) { isVideo = true; } else if (typeof (HTMLImageElement) !== 'undefined' && pixels instanceof HTMLImageElement) { isImage = true; // tslint:disable-next-line: no-any } else if (pixels.getContext != null) { isCanvasLike = true; } else if (typeof (ImageBitmap) !== 'undefined' && pixels instanceof ImageBitmap) { isImageBitmap = true; } else { throw new Error('pixels passed to tf.browser.fromPixels() must be either an ' + "HTMLVideoElement, HTMLImageElement, HTMLCanvasElement, ImageData " + "in browser, or OffscreenCanvas, ImageData in webworker" + " or {data: Uint32Array, width: number, height: number}, " + ("but was " + pixels.constructor.name)); } if (isVideo) { var HAVE_CURRENT_DATA_READY_STATE = 2; if (isVideo && pixels.readyState < HAVE_CURRENT_DATA_READY_STATE) { throw new Error('The video element has not loaded data yet. Please wait for ' + '`loadeddata` event on the