| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
77717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085208620872088208920902091209220932094209520962097209820992100210121022103210421052106210721082109211021112112211321142115211621172118211921202121212221232124212521262127212821292130213121322133213421352136213721382139214021412142214321442145214621472148214921502151215221532154215521562157215821592160216121622163216421652166216721682169217021712172217321742175217621772178217921802181218221832184218521862187218821892190219121922193219421952196219721982199220022012202220322042205220622072208220922102211221222132214221522162217221822192220222122222223222422252226222722282229223022312232223322342235223622372238223922402241224222432244224522462247224822492250225122522253225422552256225722582259226022612262226322642265226622672268226922702271227222732274227522762
2772278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755 |
- // Experimental
- import * as base from './base.js';
- import * as flatbuffers from './flatbuffers.js';
- import * as python from './python.js';
// Namespace objects populated throughout this file.
const pytorch = {};
const nnapi = {};
const numpy = {};
pytorch.ModelFactory = class {
    // Probes the context; on success registers the detected reader type on the context.
    async match(context) {
        const reader = await pytorch.Reader.open(context);
        return reader ? context.set(reader.type, reader) : null;
    }
    // Rejects matches shadowed by a more specific format already detected.
    // Each entry is '<context.type>|<match.type>'.
    filter(context, match) {
        const shadowed = [
            'pytorch.export|pytorch.zip',
            'pytorch.index|pytorch.zip',
            'pytorch.model.json|pytorch.data.pkl',
            'pytorch.model.json|pickle'
        ];
        return !shadowed.includes(`${context.type}|${match.type}`);
    }
    // Reads the matched target and wraps it in a pytorch.Model.
    async open(context) {
        const metadata = await pytorch.Metadata.open(context);
        const target = context.value;
        // Surface unknown type names as non-fatal errors.
        target.on('resolve', (sender, name) => {
            context.error(new pytorch.Error(`Unknown type name '${name}'.`), false);
        });
        await target.read(metadata);
        const supported = target.format && (target.modules || target.module);
        if (!supported) {
            throw new pytorch.Error("Reader not implemented.");
        }
        return new pytorch.Model(metadata, target);
    }
};
pytorch.Model = class {
    // Root model object: builds one pytorch.Graph per module found on the reader target.
    constructor(metadata, target) {
        this.format = target.format;
        this.producer = target.producer || '';
        this.modules = [];
        if (target.module) {
            const graph = new pytorch.Graph(target.execution, metadata, null, '', target.module);
            this.modules.push(graph);
            delete target.execution;
        } else if (target.modules) {
            for (const [name, value] of target.modules) {
                const graph = new pytorch.Graph(target.execution, metadata, null, name, value);
                this.modules.push(graph);
            }
            // Release the execution context only after ALL graphs are built.
            // Previously this delete ran inside the loop, so every module after
            // the first was constructed with target.execution === undefined.
            delete target.execution;
        }
    }
};
pytorch.Graph = class {
    // Builds a graph view from one of several module representations:
    // a TorchScript RecursiveScriptModule with a 'forward' method, a
    // torch.export ExportedProgram, a torch.fx GraphModule, a bare tensor,
    // or a plain weights container / module list (fallback).
    constructor(execution, metadata, type, name = '', module = null) {
        this.nodes = [];
        this.inputs = [];
        this.outputs = [];
        this.name = name;
        this.type = type;
        const context = new pytorch.Context(execution, metadata);
        // Value factory: a tensor always gets a fresh Value; named values are
        // deduplicated in context.values, and re-mapping an existing name with
        // a type or tensor is treated as a duplicate-definition error.
        context.values.map = (name, type, tensor) => {
            if (tensor) {
                return new pytorch.Value(name, type, null, tensor);
            }
            if (!context.values.has(name)) {
                context.values.set(name, new pytorch.Value(name, type, null, tensor));
            } else if (type || tensor) {
                throw new pytorch.Error(`Duplicate value '${name}'.`);
            }
            return context.values.get(name);
        };
        const torch = execution ? execution.torch : null;
        if (torch && module instanceof torch.jit._script.RecursiveScriptModule && module._c._has_method('forward')) {
            // --- TorchScript path: walk the JIT graph of the scripted module. ---
            const initializers = new Map();
            const graph = module.graph;
            // Constants referenced by the generated code, exposed as 'CONSTANTS.<key>'.
            const constants = module.code_with_constants[1].const_mapping;
            if (constants) {
                for (const [key, value] of constants) {
                    const name = `CONSTANTS.${key}`;
                    if (pytorch.Utility.isTensor(value)) {
                        initializers.set(value, new pytorch.Tensor(context, name, value));
                    } else if (pytorch.Utility.isObject(value)) {
                        initializers.set(value, value);
                    } else {
                        // Non-tensor, non-object constants are silently skipped.
                        // throw new pytorch.Error('Unsupported constant.');
                    }
                }
            }
            // NOTE(review): nothing is ever added to this set (the adds below are
            // commented out); the deleted.has() check later is currently inert.
            const deleted = new Set();
            const param_node = graph.param_node();
            // 'self' is the first graph input when its type is the module's own class type.
            const self = param_node && param_node.outputs().length > 0 && param_node.outputs()[0].type() === module._c._type() ? param_node.outputs()[0] : null;
            if (self) {
                // Resolve prim::GetAttr chains to concrete attribute values,
                // attaching .value and a dotted .identifier to each torch.Value.
                const getattr = (value) => {
                    if (value.value === undefined) {
                        const node = value.node();
                        if (node.kind() === 'prim::GetAttr') {
                            const [input] = node.inputs();
                            getattr(input);
                            if (input.value !== undefined) {
                                const name = node.s('name');
                                value.value = input.value.__getattr__(name);
                                value.identifier = input.identifier ? `${input.identifier}.${name}` : name;
                            }
                        }
                        // The param node's first output resolves to the module itself.
                        if (node === param_node && value === param_node.outputs()[0]) {
                            value.value = module;
                            value.identifier = '';
                        }
                    }
                };
                for (const node of graph.nodes()) {
                    for (const input of node.inputs()) {
                        getattr(input, node);
                    }
                }
                // Destroy the now-resolved prim::GetAttr nodes, depth-first.
                // Array.from snapshots uses() before node.destroy() mutates them.
                const delattr = (value) => {
                    for (const use of Array.from(value.uses())) {
                        const node = use.user;
                        if (node.kind() === 'prim::GetAttr') {
                            for (const output of node.outputs()) {
                                delattr(output);
                            }
                            // deleted.add(node);
                            node.destroy();
                        }
                    }
                };
                delattr(param_node.outputs()[0], '');
            }
            // Fold prim::Constant nodes into their single output value, then remove them.
            for (const node of graph.nodes()) {
                if (node.kind() === 'prim::Constant' && node.outputs().length === 1) {
                    const output = node.output();
                    output.identifier = output.debugName();
                    if (node.hasAttribute('value')) {
                        const kind = node.kindOf('value');
                        output.value = node[kind]('value');
                    } else if (node.output().type() instanceof torch.NoneType) {
                        output.value = null;
                    }
                    // deleted.add(node);
                    node.destroy();
                }
            }
            // Fold prim::TupleUnpack when the input resolved to an array of primitives.
            for (const node of graph.nodes()) {
                if (node.kind() === 'prim::TupleUnpack') {
                    const value = node.inputs()[0].value;
                    if (Array.isArray(value) && value.length === node.outputs().length && value.every((value) => typeof value === 'number' || typeof value === 'string' || typeof value === 'boolean')) {
                        for (let i = 0; i < node.outputs().length; i++) {
                            const output = node.outputs()[i];
                            output.value = value[i];
                        }
                        // deleted.add(node);
                        node.destroy();
                    }
                }
            }
            // Fold prim::ListConstruct of primitives into a plain array value,
            // except when consumed by prim::CallMethod.
            for (const node of graph.nodes()) {
                if (node.kind() === 'prim::ListConstruct' &&
                    node.inputs().every((value) => typeof value.value === 'number' || typeof value.value === 'string' || typeof value.value === 'boolean') &&
                    node.outputs().every((value) => value.uses().every((use) => use.user.kind() !== 'prim::CallMethod'))) {
                    node.outputs()[0].value = node.inputs().map((value) => value.value);
                    // deleted.add(node);
                    node.destroy();
                }
            }
            // Graph inputs: skip an unused 'self' input.
            // NOTE(review): 'self' may be null on this path if the param node check
            // above failed; 'self.uses()' would then throw — confirm reachability.
            for (const v of graph.inputs()) {
                if (self.uses().length === 0 && v === self) {
                    continue;
                }
                const identifier = pytorch.Utility.unique(v);
                const name = v.debugName() || identifier;
                const value = context.values.map(identifier);
                this.inputs.push(new pytorch.Argument(name, [value]));
            }
            for (const value of graph.outputs()) {
                const identifier = pytorch.Utility.unique(value);
                this.outputs.push(new pytorch.Argument(identifier, [context.values.map(identifier)]));
            }
            // Materialize remaining nodes, skipping param/return nodes and
            // single-use tensor prim::ListConstruct (inlined by pytorch.Node).
            for (const node of graph.nodes()) {
                if (deleted.has(node)) {
                    continue;
                }
                if (node === graph.param_node() ||
                    node === graph.return_node()) {
                    continue;
                }
                if (node.kind() === 'prim::ListConstruct') {
                    if (node.outputs().length === 1 &&
                        node.outputs().every((output) => output.uses().length === 1) &&
                        node.inputs().every((input) => pytorch.Utility.isTensor(input.value) || input instanceof torch.Value)) {
                        continue;
                    }
                }
                this.nodes.push(new pytorch.Node(execution, metadata, null, null, node, initializers, context));
            }
        } else if (torch && module instanceof torch.export.exported_program.ExportedProgram && module.graph) {
            // --- torch.export path: bind lifted placeholders to state, then
            // delegate the graph body to context.graph(). ---
            const exported_program = module;
            const graph = exported_program.graph;
            const graph_module = exported_program.graph_module;
            const inputs_to_parameters = exported_program.graph_signature.inputs_to_parameters;
            const inputs_to_buffers = exported_program.graph_signature.inputs_to_buffers;
            const inputs_to_lifted_tensor_constants = exported_program.graph_signature.inputs_to_lifted_tensor_constants;
            const nodes = new Map(graph.nodes.map((node) => [node.name, node]));
            for (const obj of graph.nodes) {
                if (obj.op === 'placeholder') {
                    if (inputs_to_parameters.has(obj.name)) {
                        // Placeholder lifted from a parameter: bind its tensor as an initializer.
                        const key = inputs_to_parameters.get(obj.name);
                        const parameter = exported_program.state_dict.get(key);
                        const tensor = parameter ? (parameter.data || parameter) : obj.meta.get('val');
                        const initializer = new pytorch.Tensor(context, key, tensor);
                        const value = new pytorch.Value(key, null, null, initializer);
                        context.values.set(obj, value);
                    } else if (inputs_to_buffers.has(obj.name)) {
                        // Placeholder lifted from a buffer.
                        const key = inputs_to_buffers.get(obj.name);
                        const buffer = exported_program.state_dict.get(key);
                        const tensor = buffer || obj.meta.get('val');
                        const initializer = new pytorch.Tensor(context, key, tensor);
                        const value = new pytorch.Value(key, null, null, initializer);
                        context.values.set(obj, value);
                    } else if (inputs_to_lifted_tensor_constants.has(obj.name)) {
                        // Placeholder lifted from a tensor constant.
                        const key = inputs_to_lifted_tensor_constants.get(obj.name);
                        const constant = exported_program.constants.get(key);
                        const tensor = constant && constant.data ? constant.data : obj.meta.get('val');
                        const initializer = new pytorch.Tensor(context, key, tensor);
                        const value = new pytorch.Value(key, null, null, initializer);
                        context.values.set(obj, value);
                    }
                    // Multi-use bound placeholders become explicit nodes so the
                    // initializer is rendered once and fanned out by value.
                    if (obj.users.size > 1 && context.values.has(obj)) {
                        const node = new pytorch.Node(execution, metadata, obj.name, null, obj, null, context);
                        this.nodes.push(node);
                        context.values.set(obj, node.outputs[0].value[0]);
                    }
                }
            }
            context.graph(this, graph_module, false);
            // User-visible inputs come from the graph signature, not the raw placeholders.
            for (const input_spec of exported_program.graph_signature.user_inputs) {
                if (nodes.has(input_spec)) {
                    const node = nodes.get(input_spec);
                    const value = context.value(node);
                    const argument = new pytorch.Argument(input_spec, [value]);
                    this.inputs.push(argument);
                }
            }
        } else if (torch && module instanceof torch.fx.GraphModule && module.graph) {
            // --- torch.fx path: delegate entirely to context.graph(). ---
            context.graph(this, module, true);
        } else if (pytorch.Utility.isTensor(module)) {
            // --- Single tensor: render as one node. ---
            const node = new pytorch.Node(execution, metadata, null, type, { value: module }, null, context);
            this.nodes.push(node);
        } else {
            // --- Fallback: plain weights container or a list of modules. ---
            const weights = this.type === 'weights' ? module : pytorch.Utility.weights(module);
            if (weights) {
                this.name = !this.name && typeof module.__name__ === 'string' ? module.__name__ : this.name;
                for (const [name, module] of weights) {
                    const node = new pytorch.Node(execution, metadata, name, 'Weights', module, null, context);
                    this.nodes.push(node);
                }
            } else {
                // Accept either a single module or an array of module-like objects.
                const modules = Array.isArray(module) && module.every((module) => module && !pytorch.Utility.isTensor(module) && (module._modules !== undefined || module.__class__)) ? module : [module];
                for (const module of modules) {
                    const type = this.type === 'weights' ? 'Weights' : null;
                    const node = new pytorch.Node(execution, metadata, null, type, module, null, context);
                    this.nodes.push(node);
                }
            }
        }
    }
};
pytorch.Argument = class {
    // Named argument slot: `value` is the payload (values, node(s), or a literal),
    // `type` an optional type tag, `visible` a display hint.
    constructor(name, value, type = null, visible = true) {
        Object.assign(this, { name, value, type, visible });
    }
};
pytorch.Value = class Value {
    // A named value in the graph; the initializer's type, when present, wins
    // over the explicitly supplied type.
    constructor(name, type, quantization, initializer = null) {
        if (typeof name !== 'string') {
            throw new pytorch.Error(`Invalid value identifier '${JSON.stringify(name)}'.`);
        }
        this.name = name;
        if (initializer && initializer.type) {
            this.type = initializer.type;
        } else {
            this.type = type || null;
        }
        this.quantization = quantization;
        this.initializer = initializer;
    }
};
- pytorch.Node = class {
- constructor(execution, metadata, name, type, obj, initializers, context, stack) {
- const torch = execution ? execution.torch : null;
- const builtins = execution ? execution.builtins : null;
- this.name = name || '';
- this.nodes = [];
- this.attributes = [];
- this.inputs = [];
- this.outputs = [];
- this.blocks = [];
- this.metadata = [];
- if (torch && obj instanceof torch.Node) {
- const node = obj;
- const kind = node.kind();
- const schema = node.schema();
- const inputs = node.inputs();
- const outputs = node.outputs();
- this.type = {
- name: kind.indexOf('::') === -1 ? kind : kind.split('::').pop().split('.')[0],
- identifier: kind
- };
- if (schema && schema.category) {
- this.type.category = schema.category;
- }
- const getAttribute = (node, name) => {
- const kind = node.kindOf(name);
- let value = null;
- let type = null;
- switch (kind) {
- case 's': value = node.s(name); type = 'string'; break;
- case 'i': value = node.i(name); type = 'int64'; break;
- case 'f': value = node.f(name); type = 'float32'; break;
- case 't': value = node.t(name); type = 'tensor'; break;
- case 'ss': value = node.ss(name); type = 'string[]'; break;
- case 'tys': value = node.tys(name).map((ty) => pytorch.Utility.toType(ty)); type = 'type[]'; break;
- case 'ival': value = node.ival(name); break;
- default: throw new pytorch.Error(`Unsupported attribute kind '${kind}'.`);
- }
- return [type, value];
- };
- for (const name of node.attributeNames()) {
- const [type, value] = getAttribute(node, name);
- const attribute = new pytorch.Argument(name, value, type);
- this.attributes.push(attribute);
- }
- const mapTensor = (value) => {
- if (value.identifier && pytorch.Utility.isTensor(value.value)) {
- const identifier = value.identifier;
- if (!context.values.has(identifier)) {
- const tensor = new pytorch.Tensor(context, identifier, value.value);
- context.values.set(identifier, new pytorch.Value(identifier, null, null, tensor));
- }
- return context.values.map(identifier);
- }
- let initializer = null;
- let identifier = value.hasDebugName() ? `%${value.debugName().toString()}` : `%${value.unique().toString()}`;
- if (value.value) {
- const obj = value.value;
- const hide = obj.__parent__ ? obj.__parent__.__hide__ : true;
- initializer = hide ? initializers.get(obj) : null;
- identifier = initializer ? initializer.name : identifier;
- }
- if (initializer) {
- return new pytorch.Value(identifier, null, null, initializer);
- }
- return context.values.map(identifier);
- };
- for (let i = 0; i < inputs.length; i++) {
- const input = inputs[i];
- const arg = schema && schema.arguments && i < schema.arguments.length ? schema.arguments[i] : null;
- const name = arg && arg.name ? arg.name : i.toString();
- let type = arg ? arg.real_type : null;
- let array = false;
- if (type instanceof torch.ListType) {
- array = true;
- type = type.getElementType();
- }
- let argument = null;
- if (type && type instanceof torch.ClassType) {
- const obj = input.value;
- if (!array && initializers.has(obj)) {
- const node = new pytorch.Node(execution, metadata, name, type.qualified_name(), obj, initializers, context);
- argument = new pytorch.Argument(name, node, 'object');
- } else if (array && Array.isArray(obj) && obj.every((obj) => initializers.has(obj))) {
- const node = obj.map((obj) => new pytorch.Node(execution, metadata, name, type.qualified_name(), obj, initializers, context));
- argument = new pytorch.Argument(name, node, 'object[]');
- } else if (array && input.node().kind() === 'prim::ListConstruct' && input.uses().length === 1 && input.node().inputs().every((input) => input.value)) {
- const node = input.node().inputs().map((input) => new pytorch.Node(execution, metadata, name, null, input.value, initializers, context));
- argument = new pytorch.Argument(name, node, 'object[]');
- } else if (input.value === undefined) {
- const identifier = pytorch.Utility.unique(input);
- const value = context.values.map(identifier);
- argument = new pytorch.Argument(name, [value]);
- } else {
- const node = new pytorch.Node(execution, metadata, null, null, input.value, initializers, context);
- argument = new pytorch.Argument(name, node, 'object');
- }
- } else if ((input.type() instanceof torch.TensorType || (input.type() instanceof torch.OptionalType && input.type().getElementType() instanceof torch.TensorType)) && pytorch.Utility.isTensor(input.value)) {
- const value = mapTensor(input);
- argument = new pytorch.Argument(name, [value]);
- } else if (input instanceof torch.Value && !pytorch.Utility.isTensor(input.value)) {
- if (input.value !== undefined) {
- if (Array.isArray(input.value) && input.value.every((value) => pytorch.Utility.isTensor(value))) {
- continue;
- }
- const type = input.type() ? pytorch.Utility.toType(input.type()) : null;
- let value = input.value;
- if (value && value instanceof torch._C.IValue) {
- value = pytorch.Utility.toString(value);
- }
- if (value && value instanceof builtins.complex) {
- value = new base.Complex(value.real, value.imag);
- }
- argument = new pytorch.Argument(name, value, type || 'attribute');
- } else if (input.type() instanceof torch.ListType) {
- if (input.node() && input.node().kind() === 'prim::ListConstruct' && input.uses().length === 1 &&
- input.node().inputs().every((value) => value instanceof torch.Value || value.type() instanceof torch.IntType || value.type() instanceof torch.FloatType || value.type() instanceof torch.StringType || value.type() instanceof torch.ComplexType || value.type() instanceof torch.TensorType)) {
- const list = input.node().inputs();
- const args = list.map((value) => {
- if (pytorch.Utility.isTensor(value.value)) {
- return mapTensor(value);
- }
- if (value.value !== undefined) {
- return value.value;
- }
- const identifier = pytorch.Utility.unique(value);
- return context.values.map(identifier);
- });
- const type = list.every((value) => (pytorch.Utility.isTensor(value.value)) || value.value === null) ? null : pytorch.Utility.toType(input.type());
- argument = new pytorch.Argument(name, args, type);
- } else {
- const identifier = pytorch.Utility.unique(input);
- argument = new pytorch.Argument(name, [context.values.map(identifier)]);
- }
- } else if (input.type() instanceof torch.StringType && typeof input.value === 'string') {
- argument = new pytorch.Argument(name, input.value, 'string');
- } else if (input.type() instanceof torch.BoolType && (typeof input.value === 'boolean' || input.value === 0 || input.value === 1)) {
- argument = new pytorch.Argument(name, Boolean(input.value), 'boolean');
- } else if (input.type() instanceof torch.IntType && typeof input.value === 'number') {
- argument = new pytorch.Argument(name, input.value, 'int64');
- } else if (input.type() instanceof torch.FloatType && typeof input.value === 'number') {
- argument = new pytorch.Argument(name, input.value, 'float32');
- } else if (input.type() instanceof torch.NoneType && input.value === null) {
- argument = new pytorch.Argument(name, null, 'attribute');
- } else {
- const identifier = pytorch.Utility.unique(input);
- const value = context.values.map(identifier);
- argument = new pytorch.Argument(name, [value]);
- }
- } else if (pytorch.Utility.isTensor(input.value) || input.value === undefined || input.value === null) {
- let list = [input];
- if (input.node() && node !== input.node() &&
- input.node().kind() === 'prim::ListConstruct' &&
- input.uses().length === 1 &&
- input.node().inputs().every((input) => pytorch.Utility.isTensor(input.value))) {
- list = input.node().inputs();
- }
- const args = list.map((input) => {
- let initializer = null;
- let identifier = pytorch.Utility.unique(input);
- if (input.value) {
- const value = input.value;
- const hide = value.__parent__ ? value.__parent__.__hide__ : true;
- initializer = hide ? initializers.get(value) : null;
- identifier = initializer ? initializer.name : identifier;
- }
- if (initializer) {
- return new pytorch.Value(identifier, null, null, initializer);
- }
- return context.values.map(identifier);
- });
- argument = new pytorch.Argument(name, args);
- } else if (Array.isArray(input.value) && input.value.some((value) => value instanceof torch.Value)) {
- const args = input.value.map((value) => {
- if (value instanceof torch.Value) {
- const identifier = pytorch.Utility.unique(value);
- return context.values.map(identifier);
- }
- return value;
- });
- argument = new pytorch.Argument(name, args, pytorch.Utility.toType(type));
- } else {
- throw new pytorch.Error('Unsupported input value');
- }
- this.inputs.push(argument);
- }
- for (let i = 0; i < outputs.length; i++) {
- const output = outputs[i];
- const ret = schema && schema.returns && i < schema.returns.length ? schema.returns[i] : null;
- if (ret && ret.name) {
- name = ret.name;
- } else {
- name = i === 0 && outputs.length === 1 ? 'output' : `${i}`;
- }
- let list = [output];
- if (output.uses().length === 1 &&
- output.uses()[0].user &&
- output.uses()[0].user.kind() === 'prim::ListUnpack' &&
- output.uses()[0].user.outputs().every((output) => pytorch.Utility.isTensor(output.value))) {
- list = output.uses()[0].user.outputs();
- }
- const args = list.map((output) => context.values.map(pytorch.Utility.unique(output)));
- const argument = new pytorch.Argument(name, args);
- this.outputs.push(argument);
- }
- const blocks = node.blocks();
- for (let i = 0; i < blocks.length; i++) {
- const block = blocks[i];
- const nodes = Array.from(block.nodes());
- if (nodes.length > 0) {
- const name = `block${i.toString()}`;
- const graph = { name: '', nodes: [] }; // new pytorch.Graph(execution, metadata, null, name, blocks[i]);
- const argument = new pytorch.Argument(name, graph, 'graph');
- this.blocks.push(argument);
- }
- }
- const sourceRange = node.sourceRange();
- if (sourceRange) {
- this.metadata.push(new pytorch.Argument('source', sourceRange.toString().replace(/^at\s/, '').replace(/\.$/, ''), 'attribute'));
- if (sourceRange.source()) {
- const orig = sourceRange.source().findSourceRangeThatGenerated(sourceRange);
- if (orig) {
- this.metadata.push(new pytorch.Argument('generated', orig.toString(), 'attribute'));
- }
- }
- }
- } else if (torch && obj instanceof torch.fx.node.Node) {
- if (obj.op === 'call_function') {
- let name = null;
- const target = obj.target;
- if (target instanceof torch._ops.OpOverload) {
- name = target.name();
- } else if (target instanceof torch._ops.HigherOrderOperator) {
- name = `${target.namespace}::${target.name}`;
- } else if (builtins.isinstance(target, builtins.function)) {
- name = target.__name__;
- } else if (typeof target === 'string') {
- // Handle unresolved operators
- const match = target.match(/^torch\.ops\.([^.]+)\.(.+)$/);
- if (!match) {
- throw new pytorch.Error(`Unsupported target '${target}'.`);
- }
- const [, namespace, opname] = match;
- name = `${namespace}::${opname}`;
- } else {
- throw new pytorch.Error(`Unsupported target '${target}'.`);
- }
- this.type = {
- identifier: name,
- name: name.indexOf('::') === -1 ? name : name.split('::').pop().split('.')[0]
- };
- const schema = obj.target._schema;
- if (schema && schema.category) {
- this.type.category = schema.category;
- }
- let args = obj.args.map((arg, index) => {
- if (!schema) {
- return ['', arg];
- }
- if (Array.isArray(schema.arguments) && index < schema.arguments.length) {
- return [schema.arguments[index].name, arg];
- }
- if (schema.is_vararg) {
- return ['', arg];
- }
- throw new pytorch.Error('Unsupported schema argument.');
- });
- const inputs = new Map((schema ? schema.arguments : []).map((arg) => [arg.name, arg]));
- args = args.concat(Array.from(obj.kwargs));
- for (const [name, arg] of args) {
- let type = inputs.has(name) ? pytorch.Utility.toType(inputs.get(name).real_type) : null;
- if (arg instanceof torch.fx.node.Node) {
- let argument = null;
- if (arg.op === 'get_attr' && arg.users.size === 1) {
- const subgraph = context.function(arg);
- if (subgraph) {
- argument = new pytorch.Argument(name, subgraph, 'function');
- }
- }
- if (!argument) {
- const value = context.value(arg);
- argument = new pytorch.Argument(name, [value]);
- }
- this.inputs.push(argument);
- } else if (Array.isArray(arg) && arg.every((arg) => arg instanceof torch.fx.node.Node || arg === null)) {
- const list = arg.map((arg) => arg === null ? null : context.value(arg));
- const argument = new pytorch.Argument(name, list);
- this.inputs.push(argument);
- } else if (Array.isArray(arg)) {
- const list = arg.map((arg) => arg instanceof torch.fx.node.Node ? context.value(arg) : arg);
- const argument = new pytorch.Argument(name, list, type || 'attribute');
- this.inputs.push(argument);
- } else if (arg instanceof torch.dtype || arg instanceof torch.device || arg instanceof torch.layout || arg instanceof torch.memory_format) {
- const argument = new pytorch.Argument(name, arg.toString(), type || 'attribute');
- this.inputs.push(argument);
- } else {
- const primitive = typeof arg === 'number' || typeof arg === 'boolean' || typeof arg === 'string' || arg === null;
- type = type === 'tensor' && primitive ? null : type;
- const argument = new pytorch.Argument(name, arg, type || 'attribute');
- this.inputs.push(argument);
- }
- }
- let outputs = [obj];
- if (obj.users.size > 1) {
- const users = Array.from(obj.users.keys());
- if (users.every((user) => user.op === 'call_function' && user.target.__module__ === 'operator' && user.target.__name__ === 'getitem')) {
- outputs = new Array(obj.users.size);
- for (const user of users) {
- const [, index] = user.args;
- outputs[index] = user;
- }
- }
- }
- for (let i = 0; i < outputs.length; i++) {
- const node = outputs[i];
- const value = context.value(node);
- const name = schema && schema.returns && schema.returns[i] ? schema.returns[i].name || 'output' : 'output';
- const argument = new pytorch.Argument(name, [value]);
- this.outputs.push(argument);
- }
- for (const [name, value] of obj.meta) {
- if (name === 'val' || name === 'torch_fn' ||
- (Array.isArray(value) && value.length === 0) ||
- (value instanceof Map && value.size === 0)) {
- continue;
- }
- if (typeof value === 'string') {
- const argument = new pytorch.Argument(name, value, 'string');
- this.metadata.push(argument);
- } else if (Array.isArray(value) && value.every((item) => typeof item === 'string')) {
- const argument = new pytorch.Argument(name, value, 'string[]');
- this.metadata.push(argument);
- } else if (value instanceof Map && value.size > 0) {
- // const argument = new pytorch.Argument(name, Object.fromEntries(Array.from(value)));
- // this.metadata.push(argument);
- } else {
- // const argument = new pytorch.Argument(name, value);
- // this.metadata.push(argument);
- }
- }
- } else if (obj.op === 'placeholder') {
- this.type = { name: obj.op };
- {
- const value = context.value(obj);
- const argument = new pytorch.Argument('value', [value]);
- this.inputs.push(argument);
- }
- {
- const node = new torch.fx.node.Node(null, obj.name);
- node.meta = obj.meta;
- const value = context.value(node);
- const argument = new pytorch.Argument('value', [value]);
- this.outputs.push(argument);
- }
- } else if (obj.op === 'get_attr') {
- this.type = { name: obj.op };
- const subgraph = context.function(obj);
- if (subgraph) {
- this.inputs.push(new pytorch.Argument('name', subgraph, 'function'));
- } else {
- this.inputs.push(new pytorch.Argument('name', obj.target, 'string'));
- }
- const value = context.value(obj);
- this.outputs.push(new pytorch.Argument('value', [value]));
- } else if (obj.op === 'root') {
- this.type = { name: obj.op };
- } else {
- throw new pytorch.Error(`Unsupported node operation '${obj.op}'.`);
- }
- } else {
- if (torch && obj instanceof torch.ScriptObject) {
- type = obj._type().qualified_name();
- obj = obj._ivalue;
- } else if (torch && obj instanceof torch.jit._script.RecursiveScriptModule && obj._c && obj._c.qualified_name) {
- type = obj._c._type();
- const target = {
- _modules: obj._modules,
- _parameters: obj._parameters,
- _buffers: obj._buffers,
- };
- for (let i = 0; i < type.numAttributes(); i++) {
- if (!type.is_parameter(i) && !type.is_buffer(i) && !type.getAttribute(i).is_module()) {
- const k = type.getAttributeName(i);
- target[k] = obj.__getattr__(k);
- }
- }
- type = obj._c.qualified_name;
- obj = target;
- }
- if (!type) {
- if (torch && obj instanceof torch.jit._script.RecursiveScriptModule && obj._c && obj._c.qualified_name) {
- type = obj._c.qualified_name;
- } else if (pytorch.Utility.isInstance(obj, 'builtins.function')) {
- type = `${obj.__module__}.${obj.__name__}`;
- obj = {};
- } else if (obj && obj.__class__ && obj.__class__.__module__ && obj.__class__.__name__) {
- type = `${obj.__class__.__module__}.${obj.__class__.__name__}`;
- } else {
- type = 'builtins.object';
- }
- }
- if (type instanceof nnapi.Graph) {
- this.type = type;
- } else {
- const key = type.startsWith('__torch__.') ? type.substring(10) : type;
- const value = metadata.type(key);
- this.type = value ? { ...value } : { name: type };
- this.type.identifier = type;
- }
- stack = stack || new Set();
- const weights = pytorch.Utility.weights(obj);
- if (weights) {
- const type = this.type.name;
- this.type = new pytorch.Graph(execution, metadata, 'weights', '', weights);
- this.type.name = type;
- } else if (obj && pytorch.Utility.isInstance(obj, 'fastai.data.core.DataLoaders')) {
- // continue
- } else if (obj && pytorch.Utility.isInstance(obj, '__torch__.torch.classes._nnapi.Compilation')) {
- // continue
- } else if (obj && type === 'builtins.bytearray') {
- const argument = new pytorch.Argument('value', Array.from(obj), 'byte[]');
- this.inputs.push(argument);
- } else if (obj) {
- const inputs = new Map(Array.isArray(this.type.inputs) ? this.type.inputs.map((input) => [input.name, input]) : []);
- const list = obj instanceof Map ? Array.from(obj) : Object.entries(obj);
- for (const [name, value] of list) {
- if (name === '__class__' || name === '__name__') {
- continue;
- } else if (pytorch.Utility.isInstance(value, 'collections.OrderedDict') && value instanceof Map && value.size === 0) {
- continue;
- } else if (pytorch.Utility.isInstance(value, 'builtins.set') && value instanceof Set && value.size === 0) {
- continue;
- } else if (pytorch.Utility.isInstance(value, 'builtins.list') && Array.isArray(value) && value.length === 0) {
- continue;
- } else if (pytorch.Utility.isInstance(value, 'torch.Size') && Array.isArray(value) && value.length === 0) {
- continue;
- }
- let parameters = null;
- if ((name === '_parameters' || name === '_buffers') && value instanceof Map) {
- parameters = value;
- } else if (pytorch.Utility.isTensor(value) || (Array.isArray(value) && value.every((tensor) => pytorch.Utility.isTensor(tensor)))) {
- parameters = new Map([[name, value]]);
- }
- if (parameters) {
- for (const [name, value] of parameters) {
- const list = Array.isArray(value) ? value.map((item) => pytorch.Utility.toTensor(item)) : [pytorch.Utility.toTensor(value)];
- const visible = inputs.has(name) ? inputs.get(name).visible || true : true;
- const args = list.filter((value) => value !== null && !value.__origin__).map((value) => {
- const name = value && value.name ? value.name : '';
- const identifier = list.length === 1 && value && value.__name__ ? value.__name__ : name;
- let tensor = null;
- if (initializers && initializers.has(value)) {
- tensor = initializers.get(value);
- } else {
- value = value.__source__ ? value.__source__ : value;
- tensor = value ? new pytorch.Tensor(context, identifier, value) : null;
- }
- return new pytorch.Value(identifier, null, null, tensor);
- });
- const argument = new pytorch.Argument(name, args, null, visible);
- this.inputs.push(argument);
- if (value && value.__variable__) {
- const argument = new pytorch.Argument(name, [context.values.map(value.__variable__)]);
- this.outputs.push(argument);
- }
- }
- continue;
- }
- if (pytorch.Utility.isTensor(value)) {
- const tensor = new pytorch.Tensor(context, '', value);
- const argument = new pytorch.Argument(name, tensor, 'tensor');
- this.inputs.push(argument);
- } else if (value && pytorch.Utility.isInstance(value, 'torch.dtype')) {
- const node = new pytorch.Node(execution, metadata, null, value.toString(), {}, null, context);
- const argument = new pytorch.Argument(name, node, 'object');
- this.inputs.push(argument);
- } else if (Array.isArray(value) && value.some((value) => pytorch.Utility.isTensor(value)) && value.every((value) => pytorch.Utility.isTensor(value) || value === null)) {
- const tensors = value.map((value) => value === null ? value : new pytorch.Tensor(context, '', value));
- const argument = new pytorch.Argument(name, tensors, 'tensor[]');
- this.inputs.push(argument);
- } else if (pytorch.Utility.isInstance(value, 'numpy.ndarray') || pytorch.Utility.isInstance(value, 'numpy.matrix')) {
- const tensor = new numpy.Tensor(value);
- const argument = new pytorch.Argument(name, tensor, 'tensor');
- this.inputs.push(argument);
- } else if (Array.isArray(value) && value.every((value) => typeof value === 'string')) {
- const argument = new pytorch.Argument(name, value, 'string[]');
- this.inputs.push(argument);
- } else if (Array.isArray(value) && value.every((value) => typeof value === 'number')) {
- const argument = new pytorch.Argument(name, value, 'attribute');
- this.inputs.push(argument);
- } else if (name === '_modules' && pytorch.Utility.isInstance(value, 'collections.OrderedDict') &&
- value instanceof Map && Array.from(value).every(([, value]) => value === null || value.__class__)) {
- const list = Array.from(value).filter(([, value]) => !stack.has(value)).map(([name, obj]) => {
- stack.add(value);
- const type = obj === null ? 'builtins.NoneType' : `${obj.__class__.__module__}.${obj.__class__.__name__}`;
- const node = new pytorch.Node(execution, metadata, this.name ? `${this.name}.${name}` : name, type, obj, initializers, context, stack);
- stack.delete(value);
- return node;
- });
- const argument = new pytorch.Argument(name, list, 'object[]');
- this.inputs.push(argument);
- } else if (value && Array.isArray(value) && value.length > 0 && value.every((obj) => Array.isArray(obj) && obj.every((item) => typeof item === 'string' || typeof item === 'number'))) {
- const argument = new pytorch.Argument(name, value, 'attribute');
- this.inputs.push(argument);
- } else if (value && Array.isArray(value) && value.length > 0 && value.every((obj) => obj && (obj.__class__ || obj === Object(obj)))) {
- const list = value.filter((value) => !stack.has(value));
- const nodes = list.map((value) => {
- stack.add(value);
- const node = new pytorch.Node(execution, metadata, null, null, value, initializers, context, stack);
- stack.delete(value);
- return node;
- });
- const argument = new pytorch.Argument(name, nodes, 'object[]');
- this.inputs.push(argument);
- } else if (value && (value.__class__ || typeof value === 'object') && !stack.has(value)) {
- stack.add(value);
- const node = new pytorch.Node(execution, metadata, null, null, value, initializers, context, stack);
- stack.delete(value);
- const visible = name !== '_metadata' || !pytorch.Utility.isMetadataObject(value);
- const argument = new pytorch.Argument(name, node, 'object', visible);
- this.inputs.push(argument);
- } else {
- let schema = metadata.attribute(this.type.identifier, name);
- schema = name === 'training' ? { type: 'boolean', visible: false } : schema;
- let visible = true;
- let obj = value;
- const type = schema && schema.type ? schema.type : 'attribute';
- if (schema) {
- if (schema.visible === false) {
- visible = false;
- } else if (schema.default !== undefined) {
- if (Array.isArray(obj)) {
- if (Array.isArray(schema.default)) {
- visible = obj.length !== schema.default || !obj.every((item, index) => item === schema.default[index]);
- } else {
- visible = !obj.every((item) => item === schema.default);
- }
- } else {
- visible = obj !== schema.default;
- }
- }
- }
- if (Array.isArray(obj) && obj.length > 0 && obj.every((obj) => obj && obj.__class__ && obj.__class__.__module__ && obj.__class__.__module__.startsWith('torch.nn'))) {
- obj = '?';
- }
- const argument = new pytorch.Argument(name, obj, type, visible);
- this.inputs.push(argument);
- }
- }
- }
- }
- }
- };
pytorch.Tensor = class {
    // Viewer-facing tensor wrapper: captures dtype/shape/layout metadata and
    // exposes a lazily-sliced view of the underlying storage bytes.
    constructor(context, name, tensor) {
        this.name = name || '';
        this.attributes = [];
        // Parameter-like wrappers hold the actual tensor in '.data'.
        tensor = tensor.data ? tensor.data : tensor;
        const storage = tensor.storage();
        this.type = context.type(tensor);
        const layout = this.type.layout;
        const size = this.type.shape.dimensions || [];
        if (layout) {
            // Sparse layout: data is represented by separate indices/values tensors.
            this.indices = new pytorch.Tensor(context, '', tensor.indices);
            this._values = new pytorch.Tensor(context, '', tensor.values);
        } else {
            // '<' denotes little-endian byte order (numpy-style encoding marker).
            this.encoding = '<';
            this.indices = null;
            this.stride = tensor.stride();
            const stride = this.stride;
            const offset = tensor.storage_offset();
            if (storage) {
                this._data = storage.data;
                let length = 0;
                if (!Array.isArray(stride)) {
                    length = storage.size();
                } else if (size.every((v) => v !== 0)) {
                    // Elements spanned by the strided view: 1 + sum(stride[i] * (size[i] - 1)).
                    length = size.reduce((a, v, i) => a + stride[i] * (v - 1), 1);
                }
                if (storage && typeof storage.size === 'function') {
                    // Record a byte window only when the tensor is a strict
                    // sub-view of its storage (non-zero offset or partial span).
                    if (offset !== 0 || length !== storage.size()) {
                        const itemsize = storage.dtype.itemsize();
                        this._offset = itemsize * offset;
                        this._length = itemsize * length;
                    }
                }
            }
        }
        const type = tensor.__class__ || {};
        if (type.tensor_attribute_names) {
            // Tensor subclasses may declare extra serialized scalar attributes.
            for (const name of type.tensor_attribute_names) {
                let value = tensor[name];
                if (value !== undefined) {
                    if (value && typeof value.__reduce__ === 'function') {
                        value = value.__reduce__();
                    }
                    const attribute = new pytorch.Argument(name, value, 'attribute');
                    this.attributes.push(attribute);
                }
            }
        }
        if (type.tensor_data_names) {
            // Nested tensors referenced by a wrapper tensor subclass.
            for (const name of type.tensor_data_names) {
                const value = tensor[name];
                if (value !== undefined && pytorch.Utility.isTensor(value)) {
                    const attribute = new pytorch.Argument(name, new pytorch.Tensor(context, name, value), 'tensor');
                    this.attributes.push(attribute);
                }
            }
        }
    }
    get values() {
        // Raw bytes of the tensor data (or the values tensor for sparse
        // layouts). Returns null when no storage is available.
        const type = this.type.layout;
        if (type && type.startsWith('sparse.')) {
            return this._values;
        }
        if (this._data instanceof Uint8Array) {
            return this._data;
        }
        if (this._data && this._offset !== undefined) {
            // Slice the storage stream without disturbing its read position.
            const stream = this._data;
            const position = stream.position;
            stream.seek(this._offset);
            const values = stream.peek(this._length);
            stream.seek(position);
            return values;
        }
        if (this._data) {
            return this._data.peek();
        }
        return null;
    }
};
pytorch.TensorType = class {
    // Tensor type descriptor: element data type plus shape, with an optional
    // storage layout (e.g. 'sparse.coo') when the tensor is not dense.
    constructor(dataType, shape, layout) {
        this.dataType = dataType;
        this.shape = shape;
        this.layout = layout;
    }
    toString() {
        // Rendered as '<dtype><shape>', e.g. 'float32[1,3,224,224]'.
        return `${this.dataType}${this.shape.toString()}`;
    }
};
pytorch.TensorShape = class {
    // Ordered list of tensor dimensions; empty means scalar/unknown.
    constructor(dimensions = []) {
        this.dimensions = dimensions;
    }
    toString() {
        const dims = this.dimensions;
        if (!dims || dims.length === 0) {
            // Scalar or unknown shape renders as an empty string.
            return '';
        }
        const parts = dims.map((dim) => dim.toString());
        return `[${parts.join(',')}]`;
    }
};
pytorch.Context = class {
    // Shared conversion state for turning a torch.fx graph into viewer
    // nodes/values: deduplicates values and tracks submodule subgraphs.
    constructor(execution, metadata) {
        this.execution = execution;
        this.torch = execution ? execution.__import__('torch') : null;
        this.metadata = metadata;
        this.values = new Map();  // fx node -> pytorch.Value (one value per node)
        this.modules = new Map(); // get_attr node -> submodule (lazily wrapped as pytorch.Graph)
    }
    type(tensor) {
        // Build a pytorch.TensorType from a tensor-like object, normalizing
        // float8/float4 dtype spellings to the viewer's naming.
        let dataType = tensor.dtype.__reduce__();
        switch (dataType) {
            case 'float8_e5m2': dataType = 'float8e5m2'; break;
            case 'float8_e5m2fnuz': dataType = 'float8e5m2fnuz'; break;
            case 'float8_e4m3fn': dataType = 'float8e4m3fn'; break;
            case 'float8_e4m3fnuz': dataType = 'float8e4m3fnuz'; break;
            case 'float8_e8m0fnu': dataType = 'float8e8m0fnu'; break;
            case 'float4_e2m1fn_x2': dataType = 'float4e2m1fnx2'; break;
            default: break;
        }
        const size = tensor.size ? tensor.size() : tensor.shape;
        const shape = new pytorch.TensorShape(size || []);
        const layout = tensor.layout ? tensor.layout.__str__() : null;
        if (layout && layout.startsWith('torch.sparse_')) {
            // e.g. 'torch.sparse_coo' becomes layout 'sparse.coo'.
            return new pytorch.TensorType(dataType, shape, layout.split('.').pop().replace('_', '.'));
        }
        return new pytorch.TensorType(dataType, shape);
    }
    value(obj) {
        // Return the (cached) pytorch.Value for an fx node, typing it from the
        // node's 'val' meta entry when a dtype is available.
        const torch = this.torch;
        if (obj instanceof torch.fx.node.Node) {
            if (!this.values.has(obj)) {
                let type = null;
                const val = obj.meta ? obj.meta.get('val') : null;
                if (val && val.dtype) {
                    type = this.type(val);
                }
                const value = new pytorch.Value(obj.name, type);
                this.values.set(obj, value);
            }
            return this.values.get(obj);
        }
        return null;
    }
    function(obj) {
        // Return the subgraph registered for a get_attr node, wrapping the raw
        // submodule in a pytorch.Graph on first access; null when unknown.
        const torch = this.torch;
        if (obj instanceof torch.fx.node.Node) {
            let subgraph = this.modules.get(obj);
            if (subgraph) {
                if (subgraph instanceof pytorch.Graph === false) {
                    subgraph = new pytorch.Graph(this.execution, this.metadata, 'function', obj.target, subgraph);
                    this.modules.set(obj, subgraph);
                }
                return subgraph;
            }
        }
        return null;
    }
    graph(target, module, inputs) {
        // Populate 'target' (a graph-like object with inputs/outputs/nodes)
        // from an fx GraphModule. 'inputs' toggles emitting placeholder inputs.
        const graph = module.graph;
        if (module.named_modules) {
            // Pre-register submodules referenced via get_attr so they can be
            // rendered as subgraphs rather than plain attributes.
            const modules = module.named_modules();
            for (const obj of graph.nodes) {
                if (obj.op === 'get_attr') {
                    const submodule = modules.get(obj.target);
                    if (submodule && submodule.graph) {
                        this.modules.set(obj, submodule);
                    }
                }
            }
        }
        let controlDependency = null;
        for (const obj of graph.nodes) {
            if (obj.op === 'placeholder') {
                if (inputs) {
                    const value = this.value(obj);
                    const argument = new pytorch.Argument(obj.name, [value]);
                    target.inputs.push(argument);
                }
                continue;
            }
            if (obj.op === 'call_function') {
                // operator.getitem nodes are folded into their producer's outputs.
                if (obj.target.__module__ === 'operator' && obj.target.__name__ === 'getitem') {
                    continue;
                }
            }
            if (obj.op === 'get_attr') {
                // A get_attr consumed by a single user is inlined at the use site.
                if (this.modules.has(obj) && obj.users.size === 1) {
                    continue;
                }
            }
            if (obj.op === 'output') {
                for (const output of obj.args) {
                    if (output === null || output === undefined) {
                        continue;
                    }
                    if (output.op === 'call_function' && output.target.__module__ === 'operator' && output.target.__name__ === 'getitem') {
                        continue;
                    }
                    const value = this.value(output);
                    const argument = new pytorch.Argument(output.name, [value]);
                    target.outputs.push(argument);
                }
                continue;
            }
            const node = new pytorch.Node(this.execution, this.metadata, obj.name, null, obj, null, this);
            target.nodes.push(node);
            if (controlDependency) {
                // Chain this node after the preceding side-effect-only node.
                node.controlDependencies = node.controlDependencies || [];
                node.controlDependencies.push(controlDependency);
                controlDependency = null;
            }
            if (obj.op === 'call_function' && obj.users.size === 0) {
                // A call with no users is kept for its side effects; expose its
                // output so the next node can depend on it for ordering.
                controlDependency = node.outputs[0].value[0];
            }
        }
    }
};
pytorch.Reader = class {
    // Probe each known container format in priority order and return the
    // first reader that recognizes the context, or null when none match.
    static async open(context) {
        const candidates = [
            pytorch.Reader.Zip,
            pytorch.Reader.Pickle,
            pytorch.Reader.Tar,
            pytorch.Reader.data_pkl,
            pytorch.Reader.torch_utils,
            pytorch.Reader.Mobile,
            pytorch.Reader.ModelJson,
            pytorch.Reader.IR,
            pytorch.Reader.Index,
            pytorch.Reader.ExportedProgram
        ];
        for (const candidate of candidates) {
            // Detection order is significant, so formats are probed one at a time.
            // eslint-disable-next-line no-await-in-loop
            const reader = await candidate.open(context);
            if (reader) {
                return reader;
            }
        }
        return null;
    }
    constructor() {
        // [event, callback] pairs replayed onto the execution at read() time.
        this._events = [];
    }
    async read() {
    }
    on(event, callback) {
        this._events.push([event, callback]);
    }
};
pytorch.Reader.Tar = class extends pytorch.Reader {
    // Legacy PyTorch v0.1.1 tar archives containing a 'pickle' entry.
    static async open(context) {
        const entries = await context.peek('tar');
        const match = entries instanceof Map && entries.has('pickle');
        return match ? new pytorch.Reader.Tar(entries) : null;
    }
    constructor(entries) {
        super();
        this.type = 'pytorch.tar';
        this.entries = entries;
    }
    async read() {
        this.format = 'PyTorch v0.1.1';
        const execution = new python.Execution();
        for (const [name, callback] of this._events) {
            execution.on(name, callback);
        }
        const torch = execution.__import__('torch');
        this.module = torch.load(this.entries);
        // Entries are consumed by torch.load; drop the reference afterwards.
        delete this.entries;
    }
};
pytorch.Reader.Pickle = class extends pytorch.Reader {
    // Bare pickle streams carrying the torch long-record signature.
    static async open(context) {
        const stream = context.stream;
        // 0x80 <any protocol byte> followed by the fixed signature bytes.
        const signature = [0x80, undefined, 0x8a, 0x0a, 0x6c, 0xfc, 0x9c, 0x46, 0xf9, 0x20, 0x6a, 0xa8, 0x50, 0x19];
        if (!stream || stream.length < signature.length) {
            return null;
        }
        const buffer = stream.peek(signature.length);
        const match = signature.every((value, index) => value === undefined || buffer[index] === value);
        return match ? new pytorch.Reader.Pickle(stream) : null;
    }
    constructor(stream) {
        super();
        this.type = 'pytorch.pickle';
        this.stream = stream;
    }
    async read() {
        this.format = 'PyTorch v0.1.10';
        // Materialize the buffer only below the ~2 GB typed-array limit;
        // otherwise hand torch.load the stream itself.
        const data = this.stream.length < 0x7ffff000 ? this.stream.peek() : this.stream;
        delete this.stream;
        const execution = new python.Execution();
        for (const [name, callback] of this._events) {
            execution.on(name, callback);
        }
        const torch = execution.__import__('torch');
        this.module = torch.load(data);
    }
};
pytorch.Reader.data_pkl = class extends pytorch.Reader {
    // Loose '.pkl' payloads: TorchScript objects, bare tensors, tensor
    // collections, or module-shaped objects.
    static async open(context) {
        const obj = await context.peek('pkl');
        if (!obj) {
            return null;
        }
        const cls = obj.__class__;
        if (cls && cls.__module__ && cls.__name__) {
            const name = `${cls.__module__}.${cls.__name__}`;
            if (name.startsWith('__torch__.')) {
                return new pytorch.Reader.data_pkl('', obj);
            }
        }
        if (pytorch.Utility.isTensor(obj)) {
            return new pytorch.Reader.data_pkl('tensor', obj);
        }
        if (Array.isArray(obj)) {
            if (obj.length > 0 && obj.every((item) => pytorch.Utility.isTensor(item))) {
                return new pytorch.Reader.data_pkl('tensor', obj);
            }
        } else if (obj instanceof Map) {
            const tensors = Array.from(obj).filter(([, value]) => pytorch.Utility.isTensor(value));
            if (tensors.length > 0) {
                return new pytorch.Reader.data_pkl('tensor', obj);
            }
        } else {
            const tensors = Object.entries(obj).filter(([, value]) => pytorch.Utility.isTensor(value));
            if (tensors.length > 0) {
                return new pytorch.Reader.data_pkl('tensor', obj);
            }
        }
        // Look for a module either at the top level or under a common key.
        for (const key of ['', 'model', 'net']) {
            const module = key === '' ? obj : obj[key];
            if (module && module._modules && pytorch.Utility.isInstance(module._modules, 'collections.OrderedDict')) {
                return new pytorch.Reader.data_pkl('module', module);
            }
        }
        return null;
    }
    constructor(type, module) {
        super();
        // 'type' records how the payload was detected; the reader type and
        // format are the same for every variant.
        this.type = 'pytorch.data.pkl';
        this.format = 'PyTorch Pickle';
        this.module = module;
    }
    async read() {
    }
};
pytorch.Reader.torch_utils = class extends pytorch.Reader {
    // Pickles referencing 'torch_utils' whose entries contain nn.Module values.
    static async open(context) {
        const stream = context.stream;
        if (!stream || stream.length <= 1) {
            return null;
        }
        const buffer = stream.peek(Math.min(1024, stream.length));
        if (buffer[0] !== 0x80) {
            // Not a pickle protocol header.
            return null;
        }
        const content = String.fromCharCode.apply(null, buffer);
        if (content.indexOf('torch_utils') === -1) {
            return null;
        }
        const obj = await context.peek('pkl');
        const hasModule = obj && Object.entries(obj).some(([, value]) => pytorch.Utility.isInstance(value, 'torch.nn.modules.module.Module'));
        return hasModule ? new pytorch.Reader.torch_utils(obj) : null;
    }
    constructor(obj) {
        super();
        this.type = 'pytorch.torch_utils';
        this.obj = obj;
    }
    async read() {
        this.format = 'PyTorch torch_utils';
        this.module = this.obj;
        delete this.obj;
    }
};
pytorch.Reader.Mobile = class extends pytorch.Reader {
    // PyTorch Mobile flatbuffer modules (file identifier 'PTMF').
    static async open(context) {
        const reader = await context.peek('flatbuffers.binary');
        const match = reader && reader.identifier === 'PTMF';
        return match ? new pytorch.Reader.Mobile(context) : null;
    }
    constructor(context) {
        super();
        this.type = 'pytorch.mobile';
        this.context = context;
    }
    async read(metadata) {
        const execution = new pytorch.Execution(null, metadata);
        for (const [name, callback] of this._events) {
            execution.on(name, callback);
        }
        const stream = this.context.stream;
        const torch = execution.__import__('torch');
        // The generated flatbuffer schema is required to decode the module.
        const schema = await this.context.require('./pytorch-schema');
        torch.mobile = schema.torch.jit.mobile;
        this.module = torch.jit.jit_module_from_flatbuffer(stream);
        const version = this.module._c._bytecode_version.toString();
        this.format = pytorch.Utility.format('PyTorch Mobile', version);
        delete this.context;
    }
};
pytorch.Reader.Zip = class extends pytorch.Reader {
    // PyTorch zip containers: TorchScript/eager 'data.pkl' archives and
    // torch.package archives ('.data/version').
    static async open(context) {
        const entries = await context.peek('zip');
        if (entries instanceof Map && entries.size > 0) {
            // Compute the length of the directory prefix shared by all entry
            // paths so archives nested under a folder are still recognized.
            let prefix = 0;
            const paths = Array.from(entries.keys()).map((path) => path.replace(/\\/g, '/').split('/').reverse());
            for (let set = new Set(); set && paths.length > 0;) {
                // Pop one leading segment per path; stop when segments diverge
                // or some path has no directory component left (null marker).
                set = new Set(paths.map((path) => path.length > 1 ? path.pop() : null));
                set = set.size > 1 || set.keys().next().value === null ? null : set;
                prefix += set ? set.keys().next().value.length + 1 : 0;
            }
            const records = new Map(Array.from(entries).map(([name, value]) => [name.substring(prefix), value]));
            if (records.has('model.json')) {
                // Legacy TorchScript model.json archives are not handled here.
                return null;
            }
            if (records.has('data.pkl')) {
                return new pytorch.Reader.Zip(entries);
            }
            if (records.has('.data/version') && !records.has('archive_format')) {
                return new pytorch.Reader.Package(entries);
            }
        }
        return null;
    }
    constructor(entries) {
        super();
        this.type = 'pytorch.zip';
        // https://github.com/pytorch/pytorch/blob/main/torch/csrc/jit/OVERVIEW.md
        // https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/docs/serialization.md
        this._entries = entries;
    }
    async read(metadata) {
        this.execution = new pytorch.Execution(null, metadata);
        for (const event of this._events) {
            this.execution.on(event[0], event[1]);
        }
        const torch = this.execution.__import__('torch');
        const reader = new torch.PyTorchFileReader(this._entries);
        // 'constants.pkl' marks a TorchScript archive (vs. a plain checkpoint).
        let torchscript = reader.has_record('constants.pkl');
        const version = reader.version();
        if (torchscript) {
            metadata.register(this.execution);
            this.module = torch.jit.load(reader);
            torchscript = this.module._c._has_method('forward');
            if (torchscript) {
                // console.log(this.module.graph.toString());
                torch._C._jit_pass_inline(this.module.graph);
                // console.log(this.module.graph.toString());
            }
        } else {
            // Plain checkpoint: load all records through torch.load.
            const records = reader.get_all_records().map((key) => [key, reader.get_record(key)]);
            const entries = new Map(records);
            this.module = torch.load(entries);
        }
        const name = torchscript ? 'TorchScript' : 'PyTorch';
        this.format = pytorch.Utility.format(name, version);
        delete this._model;
        delete this._entries;
    }
};
pytorch.Reader.ModelJson = class extends pytorch.Reader {
    // Legacy TorchScript archives described by a top-level 'model.json'.
    static async open(context) {
        const identifier = context.identifier;
        if (identifier === 'model.json') {
            const model = await context.peek('json');
            if (model && model.mainModule) {
                const entries = new Map();
                entries.set('model.json', context.stream);
                return new pytorch.Reader.ModelJson(context, entries, model);
            }
        }
        return null;
    }
    constructor(context, entries, model) {
        super();
        this.type = 'pytorch.model.json';
        this._context = context;
        this._entries = entries;
        this._model = model;
    }
    async read(metadata) {
        pytorch.proto = await this._context.require('./pytorch-proto');
        // Collect every sibling file referenced by model.json: tensor data
        // keys, plus per-module torchscript sources, plus fixed entries.
        const keys = [
            'attributes.pkl',
            'version',
            ...this._model.tensors.filter((tensor) => tensor && tensor.data && tensor.data.key).map((tensor) => tensor.data.key)
        ];
        const walk = (module) => {
            if (module.torchscriptArena && module.torchscriptArena.key) {
                keys.push(module.torchscriptArena.key);
            }
            for (const submodule of module.submodules || []) {
                walk(submodule);
            }
        };
        walk(this._model.mainModule);
        // Fetch all referenced files in parallel; missing files resolve to null.
        const values = await Promise.all(keys.map((name) => this._context.fetch(name).then((context) => context.stream).catch(() => null)));
        for (let i = 0; i < keys.length; i++) {
            if (values[i]) {
                this._entries.set(keys[i], values[i]);
            }
        }
        this.execution = new pytorch.Execution(null, metadata);
        this.execution.proto = pytorch.proto;
        for (const event of this._events) {
            this.execution.on(event[0], event[1]);
        }
        const torch = this.execution.__import__('torch');
        const reader = new torch.PyTorchFileReader(this._entries);
        if (this._model && this._model.producerName) {
            this.producer = this._model.producerName + (this._model.producerVersion ? ` v${this._model.producerVersion}` : '');
        }
        // 'attributes.pkl' only exists in the newer v1.1 layout.
        this.format = reader.has_record('attributes.pkl') ? 'TorchScript v1.1' : 'TorchScript v1.0';
        metadata.register(this.execution);
        this.module = torch.jit.load(reader);
        if (this.module._c._has_method('forward')) {
            // console.log(this.module.graph.toString());
            torch._C._jit_pass_inline(this.module.graph);
            // console.log(this.module.graph.toString());
        }
        delete this._context;
        delete this._model;
        delete this._entries;
    }
};
pytorch.Reader.IR = class extends pytorch.Reader {
    // Textual TorchScript IR dumps (first line starts with 'graph(').
    static async open(context) {
        const reader = await context.read('text', 0x100);
        if (!reader || reader.length === 0) {
            return null;
        }
        const line = reader.read('\n');
        return line.startsWith('graph(') ? new pytorch.Reader.IR(context) : null;
    }
    constructor(context) {
        super();
        this.type = 'pytorch.ir';
        this.context = context;
    }
    async read(metadata) {
        this.format = 'TorchScript IR';
        this.execution = new pytorch.Execution(null, metadata);
        for (const [name, callback] of this._events) {
            this.execution.on(name, callback);
        }
        // Detection works, but no parser exists for this format yet.
        throw new pytorch.Error('TorchScript IR parser not implemented.');
    }
};
pytorch.Reader.Index = class extends pytorch.Reader {
    // Sharded checkpoint index files: JSON with a 'weight_map' of
    // tensor-name -> '<shard>.bin' entries.
    static async open(context) {
        const obj = await context.peek('json');
        if (obj && obj.weight_map) {
            const entries = Object.entries(obj.weight_map);
            const valid = entries.length > 0 && entries.every(([, value]) => typeof value === 'string' && value.endsWith('.bin'));
            if (valid) {
                return new pytorch.Reader.Index(context, entries);
            }
        }
        return null;
    }
    constructor(context, entries) {
        super();
        this.type = 'pytorch.index';
        this.context = context;
        this._entries = entries;
    }
    async read(metadata) {
        this.format = 'PyTorch';
        const weight_map = new Map(this._entries);
        const keys = new Set(weight_map.keys());
        const files = Array.from(new Set(weight_map.values()));
        const contexts = await Promise.all(files.map((name) => this.context.fetch(name)));
        this.execution = new pytorch.Execution(null, metadata);
        for (const [name, callback] of this._events) {
            this.execution.on(name, callback);
        }
        const torch = this.execution.__import__('torch');
        const archives = await Promise.all(contexts.map((context) => context.peek('zip')));
        // Report a versioned format only when every shard agrees on it.
        const formats = new Set(archives.map((entries) => {
            const reader = new torch.PyTorchFileReader(entries);
            const version = reader.version();
            return pytorch.Utility.format('PyTorch', version);
        }));
        if (formats.size === 1) {
            [this.format] = formats;
        }
        // Merge all shards, keeping only tensors listed in the weight map.
        const entries = new Map();
        for (const archive of archives) {
            const shard = torch.load(archive);
            for (const [key, value] of shard) {
                if (keys.has(key)) {
                    entries.set(key, value);
                }
            }
        }
        this.module = entries;
        delete this.context;
        delete this._entries;
    }
};
// Reader for torch.export exported programs. Handles two layouts:
// - a standalone JSON exported program (schema_version + graph_module), or
// - a PT2 archive, detected through its tiny 'archive_format' marker file
//   whose content is the string 'pt2'.
pytorch.Reader.ExportedProgram = class extends pytorch.Reader {
    static async open(context) {
        const program = await context.peek('json');
        if (program && program.schema_version && program.graph_module) {
            return new pytorch.Reader.ExportedProgram(context, program);
        }
        if (context.identifier === 'archive_format' && context.stream && context.stream.length < 10) {
            const buffer = context.stream.peek();
            const archive_format = String.fromCharCode.apply(null, buffer);
            if (archive_format === 'pt2') {
                return new pytorch.Reader.ExportedProgram(context, null, context);
            }
        }
        return null;
    }
    constructor(context, exported_program, archive_format) {
        super();
        this.type = 'pytorch.export';
        this.context = context;
        this.archive_format = archive_format; // truthy only for PT2 archives
        this.exported_program = exported_program; // parsed JSON or null
    }
    // Collects the archive (or standalone) files into a virtual file map 'f'
    // and hands it to torch.export.pt2_archive._package.load_pt2.
    async read(metadata) {
        this.format = 'PyTorch Export';
        const f = new Map();
        const exported_programs = new Map();
        if (this.archive_format) {
            // PT2 archive layout: one 'models/<name>.json' per model plus
            // weights, constants and sample inputs under 'data/'.
            for (const name of this.context.container.entries.keys()) {
                const match = name.match(/^models\/([^/]+)\.json$/);
                if (match) {
                    const [, model_name] = match;
                    /* eslint-disable no-await-in-loop */
                    const model = await this.context.fetch(`models/${model_name}.json`);
                    const exported_program = await model.read('json');
                    exported_programs.set(model_name, exported_program);
                    f.set(`models/${model_name}.json`, exported_program);
                    const sample_inputs = await this._fetch(`data/sample_inputs/${model_name}.pt`, 'zip');
                    f.set(`data/sample_inputs/${model_name}.pt`, sample_inputs);
                    const weights_config = await this._fetch(`data/weights/${model_name}_weights_config.json`, 'json');
                    if (weights_config) {
                        f.set(`data/weights/${model_name}_weights_config.json`, weights_config);
                        for (const payload_meta of Object.values(weights_config.config)) {
                            // Pickled payloads are zip archives, raw payloads are binary blobs.
                            const type = payload_meta.use_pickle ? 'zip' : 'binary';
                            const weight_data = await this._fetch(`data/weights/${payload_meta.path_name}`, type);
                            if (weight_data) {
                                f.set(`data/weights/${payload_meta.path_name}`, weight_data);
                            }
                        }
                    } else {
                        const weights = await this._fetch(`data/weights/${model_name}.pt`, 'zip');
                        f.set(`data/weights/${model_name}.pt`, weights);
                    }
                    const constants_config = await this._fetch(`data/constants/${model_name}_constants_config.json`, 'json');
                    if (constants_config) {
                        f.set(`data/constants/${model_name}_constants_config.json`, constants_config);
                        for (const payload_meta of Object.values(constants_config.config)) {
                            const type = payload_meta.use_pickle ? 'zip' : 'binary';
                            const constant_data = await this._fetch(`data/constants/${payload_meta.path_name}`, type);
                            if (constant_data) {
                                f.set(`data/constants/${payload_meta.path_name}`, constant_data);
                            }
                        }
                    } else {
                        // Pass 'zip' explicitly (mirrors the weights fallback above);
                        // without a type, _fetch hits its default throw and always
                        // returned null, silently dropping the constants payload.
                        const constants = await this._fetch(`data/constants/${model_name}.pt`, 'zip');
                        f.set(`data/constants/${model_name}.pt`, constants);
                    }
                    /* eslint-enable no-await-in-loop */
                }
            }
            const byteorder = await this._fetch('byteorder', 'text') || 'little';
            f.set('byteorder', byteorder);
        } else {
            // Standalone exported program with sibling serialized files.
            this.version = await this._fetch('version', 'text') || '';
            this.version = this.version.split('\n').shift().trim();
            const weights = await this._fetch('serialized_state_dict.pt', 'zip') || await this._fetch('serialized_state_dict.json', 'zip');
            const constants = await this._fetch('serialized_constants.pt', 'zip') || await this._fetch('serialized_constants.json', 'zip');
            const sample_inputs = await this._fetch('serialized_example_inputs.pt', 'zip');
            f.set('models/model.json', this.exported_program);
            f.set('data/weights/model.pt', weights);
            f.set('data/constants/model.pt', constants);
            f.set('data/sample_inputs/model.pt', sample_inputs);
            exported_programs.set('', this.exported_program);
        }
        if (!this.version) {
            // Fall back to the schema version when all programs agree on one.
            const versions = new Set();
            for (const exported_program of exported_programs.values()) {
                const schema_version = exported_program.schema_version;
                if (schema_version && schema_version.major && schema_version.minor) {
                    versions.add(`${schema_version.major}.${schema_version.minor}`);
                }
            }
            if (versions.size === 1) {
                this.version = versions.values().next().value;
            }
        }
        this.format = this.version ? `${this.format} v${this.version}` : this.format;
        this.execution = new python.Execution();
        for (const event of this._events) {
            this.execution.on(event[0], event[1]);
        }
        metadata.register(this.execution);
        const torch = this.execution.__import__('torch');
        // Inline graph constants are base64-encoded zip archives; decode them
        // into archive entry maps before loading.
        for (const exported_program of exported_programs.values()) {
            if (exported_program.graph_module.graph.constants) {
                // eslint-disable-next-line no-await-in-loop
                const zip = await import('./zip.js');
                const constants = exported_program.graph_module.graph.constants;
                for (const key of Object.keys(constants)) {
                    const value = constants[key];
                    const str = atob(value);
                    const buffer = new Uint8Array(str.length);
                    for (let i = 0; i < str.length; i++) {
                        buffer[i] = str.charCodeAt(i);
                    }
                    const archive = zip.Archive.open(buffer);
                    constants[key] = archive.entries;
                }
            }
        }
        delete this.exported_program;
        delete this.context;
        const pt2_contents = torch.export.pt2_archive._package.load_pt2(f);
        this.modules = pt2_contents.exported_programs;
    }
    // Fetches 'name' from the context and decodes it as 'type'
    // ('zip' | 'json' | 'text' | 'binary'); returns null when the entry is
    // missing or cannot be decoded (all errors are deliberately swallowed —
    // every payload in this archive layout is optional).
    async _fetch(name, type) {
        try {
            const context = await this.context.fetch(name);
            if (context) {
                switch (type) {
                    case 'zip':
                        return await context.peek('zip');
                    case 'json':
                        return await context.read('json');
                    case 'text': {
                        const reader = await context.read('text');
                        if (reader) {
                            return reader.read();
                        }
                        break;
                    }
                    case 'binary': {
                        if (context && context.stream) {
                            return context.stream.peek();
                        }
                        break;
                    }
                    default: {
                        // Fixed message: closing quote was missing after ${type}.
                        throw new pytorch.Error(`Unsupported context type '${type}'.`);
                    }
                }
            }
        } catch {
            // continue regardless of error
        }
        return null;
    }
};
// Execution environment extending the generic Python execution with
// TorchScript custom classes, mobile flatbuffer loading, and NNAPI support.
pytorch.Execution = class extends python.Execution {
    constructor(sources, metadata) {
        super(sources);
        this._metadata = metadata;
        // eslint-disable-next-line consistent-this
        const execution = this;
        const torch = this.torch;
        // Loads a TorchScript module serialized in the mobile flatbuffer format.
        this.registerFunction('torch.jit.jit_module_from_flatbuffer', (f) => {
            const cu = new torch.jit.CompilationUnit();
            cu.execution = execution;
            const stream = f;
            const reader = flatbuffers.BinaryReader.open(stream);
            const module = torch.mobile.serialization.Module.create(reader);
            const loader = new torch._C.FlatBuffersLoader(cu);
            const cpp_module = loader.parseModule(module);
            // parse_and_initialize_jit_module
            // const mobilem = parse_and_initialize_mobile_module_for_jit(data, jit_files, jit_constants);
            // const m = jitModuleFromSourceAndConstants(mobilem._ivalue(), jit_files, jit_constants, mobilem.bytecode_version());
            // throw new pytorch.Error('torch.jit.mobile.serialization.Module not supported.');
            return torch.jit._script.wrap_cpp_module(cpp_module);
        });
        // Android NNAPI compilation wrapper: init() captures the serialized
        // model and parameter buffers; run() materializes an nnapi.Graph node
        // in the current execution graph.
        this.registerType('__torch__.torch.classes._nnapi.Compilation', class {
            constructor() {
                this.__hide__ = true;
            }
            __init__() {
            }
            init(serialized_model_tensor, parameter_buffers) {
                this.serialized_model_tensor = serialized_model_tensor;
                this.parameter_buffers = parameter_buffers;
                const buffers = parameter_buffers.map((buffer) => buffer.__source__.storage());
                /*
                let buffers = [];
                if (!pytorch.Utility.isInstance(parameter_buffers, 'torch.Value')) {
                    buffers = parameter_buffers.map((buffer) => buffer.__source__.storage());
                }
                */
                const serialized_model = serialized_model_tensor.storage().data;
                this.serialized_model = new nnapi.SerializedModel(serialized_model, buffers);
            }
            run(inputs, outputs) {
                execution.variable(this.serialized_model_tensor);
                this.serialized_model_tensor.__count__ = (this.serialized_model_tensor.__count__ || 0) + 1;
                const type = new nnapi.Graph(this.serialized_model);
                const node = execution.graph.create(type, 0);
                execution.graph.insertNode(node);
                for (const tensor of inputs) {
                    const value = execution.variable(tensor);
                    node.addInput(value);
                }
                for (const tensor of outputs) {
                    execution.variable(tensor, node);
                }
            }
        });
        // Quantized conv packed parameters. __setstate__ receives
        // [pack_version, tensors, optional tensors]; tensors[0] is the packed
        // int config (stride/padding/dilation/output_padding/groups),
        // tensors[1] is the weight. Only pack version '2' is supported.
        this.registerType('__torch__.torch.classes.quantized.Conv2dPackedParamsBase', class {
            __setstate__(state) {
                if (state[0] !== '2') {
                    throw new pytorch.Error(`Unsupported pack version '${state[0]}'.`);
                }
                const [/* pack_version */, tensors, opt_tensors] = state;
                const packed_config = tensors[0].tolist();
                this.weight = tensors[1];
                this.bias = opt_tensors[0];
                this.stride = [packed_config[1], packed_config[2]];
                this.padding = [packed_config[3], packed_config[4]];
                this.dilation = [packed_config[5], packed_config[6]];
                this.output_padding = [packed_config[7], packed_config[8]];
                this.groups = packed_config[9];
            }
        });
        // Same layout as Conv2dPackedParamsBase (see above), 3d variant.
        this.registerType('__torch__.torch.classes.quantized.Conv3dPackedParamsBase', class {
            __setstate__(state) {
                if (state[0] !== '2') {
                    throw new pytorch.Error(`Unsupported pack version '${state[0]}'.`);
                }
                const [/* pack_version */, tensors, opt_tensors] = state;
                const packed_config = tensors[0].tolist();
                this.weight = tensors[1];
                this.bias = opt_tensors[0];
                this.stride = [packed_config[1], packed_config[2]];
                this.padding = [packed_config[3], packed_config[4]];
                this.dilation = [packed_config[5], packed_config[6]];
                this.output_padding = [packed_config[7], packed_config[8]];
                this.groups = packed_config[9];
            }
        });
        this.registerType('__torch__.torch.classes.quantized.LinearPackedParamsBase', class {
            __setstate__(state) {
                [this.weight, this.bias] = state;
            }
        });
        this.registerType('__torch__.torch.classes.quantized.EmbeddingPackedParamsBase', class {
            __setstate__(state) {
                [this.version, this.tensors, this.doubles, this.longs] = state;
            }
        });
        this.registerType('__torch__.torch.classes.rnn.CellParamsBase', class {
            __setstate__(state) {
                [this.type, this.tensors, this.doubles, this.longs, this.packed_params] = state;
            }
        });
        this.registerType('__torch__.torch.classes.xnnpack.Conv2dOpContext', class {
            __setstate__(state) {
                [this.weight, this.bias, this.stride, this.padding, this.dilation, this.groups, this.output_min, this.output_max] = state;
            }
        });
        this.registerType('__torch__.torch.classes.xnnpack.LinearOpContext', class {
            __setstate__(state) {
                [this.weight, this.bias, this.output_min, this.output_max] = state;
            }
        });
        this.registerType('__torch__.torch.classes.xnnpack.TransposeConv2dOpContext', class {
            __setstate__(state) {
                [this.weight, this.bias, this.stride, this.padding, this.output_padding, this.dilation, this.groups, this.output_min, this.output_max] = state;
            }
        });
        this.registerType('__torch__.torch.classes.tensorrt.Engine', class {
            __setstate__(state) {
                [this.abi_target, this.name, this.device, this.engine, this.input_binding_names, this.output_binding_names, this.hw_compatible, this.serialized_metadata, this.target_platform] = state;
            }
        });
        // Declarations for the custom classes registered above: method schemas
        // and attribute lists used to build the corresponding ClassType.
        const custom_classes = [
            { name: '__torch__.torch.classes._nnapi.Compilation', methods: [
                '__init__(__torch__.torch.classes._nnapi.Compilation self) -> NoneType',
                'init(__torch__.torch.classes._nnapi.Compilation self, Tensor serialized_model_tensor, Tensor[] parameter_buffers) -> NoneType',
                'init2(__torch__.torch.classes._nnapi.Compilation self, Tensor serialized_model_tensor, Tensor[] parameter_buffers, int compilation_preference, bool relax_f32_to_f16) -> NoneType',
                'run(__torch__.torch.classes._nnapi.Compilation self, Tensor[] inputs, Tensor[] outputs) -> NoneType'
            ] },
            { name: '__torch__.torch.classes.quantized.Conv2dPackedParamsBase', attributes: 'Tensor weight, Tensor bias, int[] stride, int[] padding, int[] dilation, int groups', methods: ['unpack(__torch__.torch.classes.quantized.Conv2dPackedParamsBase self) -> ((Tensor, Tensor?))'] },
            { name: '__torch__.torch.classes.quantized.Conv3dPackedParamsBase', attributes: 'Tensor weight, Tensor bias, int[] stride, int[] padding, int[] dilation, int groups', methods: ['unpack(__torch__.torch.classes.quantized.Conv3dPackedParamsBase self) -> ((Tensor, Tensor?))'] },
            { name: '__torch__.torch.classes.quantized.LinearPackedParamsBase', attributes: 'Tensor weight, Tensor? bias' },
            { name: '__torch__.torch.classes.quantized.EmbeddingPackedParamsBase', attributes: 'int version, Tensor[] tensors, float[] doubles, int[] longs', methods: [] },
            { name: '__torch__.torch.classes.rnn.CellParamsBase', attributes: 'str type, Tensor[] tensors, float[] doubles, int[] longs, __torch__.torch.classes.quantized.LinearPackedParamsBase[] packed_params' },
            { name: '__torch__.torch.classes.xnnpack.Conv2dOpContext', attributes: 'Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups, int[] output_min, int[] output_max' },
            { name: '__torch__.torch.classes.xnnpack.LinearOpContext', attributes: 'Tensor weight, Tensor bias, int[] output_min, int[] output_max' },
            { name: '__torch__.torch.classes.xnnpack.TransposeConv2dOpContext', attributes: 'Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] output_padding, int[] dilation, int groups, int[] output_min, int[] output_max' },
            { name: '__torch__.torch.classes.tensorrt.Engine' }
        ];
        // Build and register a ClassType for each declared custom class.
        for (const known_type of custom_classes) {
            const prefix = new torch._C.QualifiedName(known_type.name);
            const type = torch.ClassType.create(known_type.name, this._compilation_unit, false);
            for (const known_method of known_type.methods || []) {
                const schema = new torch.FunctionSchema(known_method);
                const name = new torch._C.QualifiedName(prefix, schema.name);
                const fn = new torch._C.BuiltinOpFunction(name, schema);
                type.addMethod(fn);
            }
            if (known_type.attributes) {
                // Parse the attribute list by wrapping it in a dummy schema.
                const schema = new torch.FunctionSchema(`(${known_type.attributes}) -> ()`);
                for (const arg of schema.arguments) {
                    type.addAttribute(arg.name, arg.real_type);
                }
            }
            torch._C.registerCustomClass(type);
        }
    }
    // Resolves 'torch.<name>(...)' and 'ops.<module>.<name>(...)' calls to
    // registered builtin operators before delegating to the base resolver.
    call(target, name, args, keywords, context) {
        const ast = this.ast;
        const torch = this.torch;
        if (target instanceof ast.Name && target.id === 'torch') {
            const fn = torch.ops.aten.__getattr__(name);
            if (fn) {
                const evalArgs = args.map((arg) => this.expression(arg, context));
                return fn.__call__(...evalArgs);
            }
        }
        if (target instanceof ast.Attribute && target.value instanceof ast.Name && target.value.id === 'ops') {
            const module = torch.ops[target.attr];
            if (!module) {
                throw new pytorch.Error(`Unknown torch.ops module '${target.attr}'.`);
            }
            const fn = module.__getattr__(name);
            if (fn) {
                const evalArgs = args.map((arg) => this.expression(arg, context));
                return fn.__call__(...evalArgs);
            }
        }
        return super.call(target, name, args, keywords, context);
    }
    // Enum subclasses are instantiated directly with their value rather than
    // through the regular constructor protocol.
    invoke(target, args) {
        if (target && Array.isArray(target.__bases__) && target.__bases__.length > 0 && target.__bases__[0] === this.enum.Enum) {
            const instance = new target();
            instance.value = args;
            return instance;
        }
        return super.invoke(target, args);
    }
    // Maps the bare name 'Enum' to the enum.Enum base class when resolving
    // base-class expressions; everything else evaluates normally.
    base(expr, context) {
        const ast = this.ast;
        if (expr instanceof ast.Name) {
            switch (expr.id) {
                case 'Enum': return this.enum.Enum;
                default: break;
            }
        }
        return this.expression(expr, context);
    }
};
// Reader for torch.package archives: scans the archive for pickled resources,
// loads the package's Python sources into the execution, and unpickles each
// resource through torch.package.PackageImporter.
pytorch.Reader.Package = class extends pytorch.Reader {
    constructor(entries) {
        super();
        this.type = 'pytorch.package';
        this.entries = entries;
    }
    async read(metadata) {
        this.execution = new pytorch.Execution(null, metadata);
        for (const event of this._events) {
            this.execution.on(event[0], event[1]);
        }
        const torch = this.execution.__import__('torch');
        const reader = new torch.PyTorchFileReader(this.entries);
        const version = reader.version();
        this.format = pytorch.Utility.format('PyTorch Package', version);
        this.modules = new Map();
        // A record is a pickled resource when it lives outside '.data/',
        // is not a source file, and starts with a pickle protocol marker
        // (0x80 followed by a protocol number below 7).
        const records = reader.get_all_records().filter((name) => {
            if (!name.startsWith('.data/') && !name.endsWith('.py')) {
                const stream = reader.get_record(name);
                if (stream && stream.length > 2) {
                    const signature = stream.peek(2);
                    if (signature[0] === 0x80 && signature[1] < 7) {
                        return true;
                    }
                }
            }
            return false;
        });
        // Split 'pkg/sub/resource' into the dotted module name and resource.
        const entries = records.map((name) => {
            const parts = name.split('/');
            const resource = parts.pop();
            const module = parts.join('.');
            return [module, resource];
        });
        if (entries.length > 0) {
            // Register the package's Python sources so unpickling can resolve
            // classes defined inside the archive.
            for (const name of reader.get_all_records()) {
                if (!name.startsWith('.data/') && name.endsWith('.py')) {
                    const stream = reader.get_record(name);
                    const buffer = stream.peek();
                    this.execution.add(name, buffer);
                }
            }
            metadata.register(this.execution);
            const importer = new torch.package.PackageImporter(reader);
            for (const entry of entries) {
                const module = importer.load_pickle(entry[0], entry[1]);
                // Fixed: the regex needs the global flag so every dot in a
                // nested module name ('a.b.c') becomes a path separator, not
                // just the first one.
                const key = `${entry[0].replace(/\./g, '/')}/${entry[1]}`;
                this.modules.set(key, module);
            }
        }
        delete this.entries;
    }
};
// Numeric codes for torch.memory_format values as they appear in
// serialized models.
pytorch.MemoryFormat = {
    Contiguous: 0,
    Preserve: 1,
    ChannelsLast: 2,
    ChannelsLast3d: 3
};
// Numeric codes for torch.layout values as they appear in serialized models.
pytorch.Layout = {
    Strided: 0,
    Sparse: 1,
    Mkldnn: 2
};
// Static helpers shared across the PyTorch readers: tensor/type predicates,
// TorchScript type and IValue formatting, archive version mapping, and
// state-dict ('weights') structure detection.
pytorch.Utility = class {
    // True when obj is a torch (or torch.cuda) *Tensor or an
    // torch.nn.parameter.Parameter instance.
    static isTensor(obj) {
        const name = obj && obj.__class__ ? obj.__class__.__module__ : null;
        switch (name) {
            case 'torch':
            case 'torch.cuda':
                return obj.__class__.__name__.endsWith('Tensor');
            case 'torch.nn.parameter':
                return obj.__class__.__name__ === 'Parameter';
            default:
                return false;
        }
    }
    // Returns the underlying tensor for tensors and Parameters (propagating
    // a __name__ tag from the Parameter to its data), or null otherwise.
    static toTensor(obj) {
        const name = obj && obj.__class__ ? obj.__class__.__module__ : null;
        switch (name) {
            case 'torch':
            case 'torch.cuda':
                return obj.__class__.__name__.endsWith('Tensor') ? obj : null;
            case 'torch.nn.parameter':
                if (obj.__class__.__name__ === 'Parameter') {
                    const data = obj.data;
                    if (typeof obj.__name__ === 'string') {
                        data.__name__ = obj.__name__;
                    }
                    return data;
                }
                return null;
            default:
                return null;
        }
    }
    // Converts a TorchScript type object to the display string used in the UI.
    static toType(type) {
        switch (type.kind()) {
            case 'OptionalType': return `${pytorch.Utility.toType(type.getElementType())}?`;
            case 'ListType': return `${pytorch.Utility.toType(type.getElementType())}[]`;
            case 'BoolType': return 'boolean';
            case 'IntType': return 'int64';
            case 'FloatType': return 'float32';
            case 'StringType': return 'string';
            case 'ComplexType': return 'complex';
            case 'NumberType': return 'scalar';
            case 'TensorType': return 'tensor';
            case 'TupleType': return `tuple<${type.elements().map((type) => pytorch.Utility.toType(type)).join(', ')}>`;
            case 'DictType': return `map<${pytorch.Utility.toType(type.getKeyType())}, ${pytorch.Utility.toType(type.getValueType())}>`;
            case 'DeviceObjType': return 'device';
            case 'SymIntType': return 'SymInt';
            case 'ScalarTypeType': return 'ScalarType';
            case 'MemoryFormat': return 'MemoryFormat';
            case 'Layout': return 'Layout';
            case 'VarType': return type.annotation_str;
            case 'NoneType': return 'None';
            case 'AnyType': return 'object';
            case 'AnyListType': return 'list';
            case 'AnyTupleType': return 'tuple';
            case 'ClassType': return type.annotation_str;
            case 'EnumType': return type.annotation_str;
            default: throw new pytorch.Error(`Unsupported type '${type.kind()}'.`);
        }
    }
    // Converts an IValue to a displayable value (number, enum name, or a
    // recursively converted list); throws for unsupported tags.
    static toString(ivalue) {
        if (ivalue.isInt()) {
            return ivalue.toInt();
        }
        if (ivalue.isDouble()) {
            return ivalue.toDouble();
        }
        if (ivalue.isEnum()) {
            return ivalue.toEnumHolder().name();
        }
        if (ivalue.isList()) {
            return ivalue.toList().map((item) => pytorch.Utility.toString(item));
        }
        // Fixed message: closing quote was missing after ${ivalue.tag}.
        throw new pytorch.Error(`Unsupported IValue '${ivalue.tag}'.`);
    }
    // Reads a named attribute off a TorchScript node using the accessor that
    // matches its attribute kind.
    static constant(node, name) {
        const kind = node.kindOf(name);
        switch (kind) {
            case 's': return node.s(name);
            case 'i': return node.i(name);
            case 'f': return node.f(name);
            case 'ss': return node.ss(name);
            case 'ival': return node.ival(name);
            default: throw new pytorch.Error(`Unsupported attribute kind '${kind}'.`);
        }
    }
    // '%<debug name>' when the value has one, otherwise '%<unique id>'.
    static unique(value) {
        return value.hasDebugName() ? `%${value.debugName().toString()}` : `%${value.unique().toString()}`;
    }
    // True for the known packed-parameter / op-context custom classes.
    static isObject(obj) {
        const type = obj && obj.__class__ && obj.__class__.__module__ && obj.__class__.__name__ ? `${obj.__class__.__module__}.${obj.__class__.__name__}` : null;
        switch (type) {
            case '__torch__.torch.classes.xnnpack.LinearOpContext':
            case '__torch__.torch.classes.xnnpack.Conv2dOpContext':
            case '__torch__.torch.classes.xnnpack.TransposeConv2dOpContext':
            case '__torch__.torch.classes.rnn.CellParamsBase':
            case '__torch__.torch.classes.rnn.CellParamsBase[]':
            case '__torch__.torch.classes.quantized.LinearPackedParamsBase':
            case '__torch__.torch.classes.quantized.Conv2dPackedParamsBase':
            case '__torch__.torch.classes.quantized.Conv3dPackedParamsBase':
            case '__torch__.torch.classes.quantized.EmbeddingPackedParamsBase':
                return true;
            default:
                return false;
        }
    }
    // True when the class 'value' (or any of its bases, recursively) has the
    // fully qualified name 'name'.
    static isSubclass(value, name) {
        if (value && value.__module__ && value.__name__) {
            return name === `${value.__module__}.${value.__name__}`;
        } else if (value && value.__bases__) {
            return value.__bases__.some((obj) => pytorch.Utility.isSubclass(obj, name));
        }
        return false;
    }
    // Instance-level variant of isSubclass.
    static isInstance(value, name) {
        return value && value.__class__ ? pytorch.Utility.isSubclass(value.__class__, name) : false;
    }
    // Maps a serialized container version number to a display string like
    // 'PyTorch v1.12'; throws for unknown versions.
    static format(name, value) {
        // https://github.com/pytorch/pytorch/blob/master/caffe2/serialize/inline_container.h
        // kProducedFileFormatVersion
        const versions = new Map([
            ['1', 'v1.3'],
            ['2', 'v1.5'], // 7a2889b014ce36fcc333b2c6de6f29f976652f84 (#28122)
            ['3', 'v1.6'], // 2ec6a30722b0ef85632a2f3e7ce6f80da403008a (#36085)
            ['4', 'v1.6'], // 95489b590f00801bdee7f41783f30874883cf6bb (#38620)
            ['5', 'v1.7'], // cb26661fe4faf26386703180a9045e6ac6d157df (#40364)
            ['6', 'v1.9'], // 3ee7637ffa50df0d9b231c7b40778ac1c390bf4a (#59714)
            ['7', 'v1.10'], // 880098a7e34a20628f960daa8eab0eb1ad566c39 (#63651)
            ['8', 'v1.11'], // b28e696516a7f0c7a6ead6da967590ce6c1d6698 (#71486)
            ['9', 'v1.11'], // 8757e21c6a4fc00e83539aa7f9c28eb11eff53c1 (#72051)
            ['10', 'v1.12'] // 4f8b986e28736b59bc46cd0873a0f36fdaa6f5b8 (#61439)
        ]);
        value = value.toString();
        if (!versions.has(value)) {
            throw new pytorch.Error(`Unsupported '${name}' version '${value}'.`);
        }
        return `${name} ${versions.get(value)}`;
    }
    // Heuristically interprets obj as a state dict and groups its tensors by
    // module path, returning a Map of module name -> { property: value }, or
    // null when obj does not look like a state dict.
    static weights(obj) {
        let type = obj && obj.__class__ && obj.__class__.__module__ && obj.__class__.__name__ ? `${obj.__class__.__module__}.${obj.__class__.__name__}` : null;
        if (type === 'torch.jit._script.RecursiveScriptModule') {
            // Flatten a script module's attributes into a plain object first.
            type = obj._c._type();
            const target = {};
            for (let i = 0; i < type.numAttributes(); i++) {
                const k = type.getAttributeName(i);
                target[k] = obj.__getattr__(k);
            }
            type = obj._c.qualified_name;
            obj = target;
        } else if (type && type !== 'builtins.dict' && type !== 'builtins.object' && type !== 'collections.OrderedDict' && type !== 'torch.nn.modules.module.Module' && type !== '__torch__.Module') {
            return null;
        }
        if (pytorch.Utility.isTensor(obj)) {
            return null;
        }
        // A plain object where most keys are dotted/piped tensor names is
        // treated like a Map-based state dict.
        if (obj instanceof Map === false && obj && !Array.isArray(obj) && Object(obj) === obj) {
            const entries = Object.entries(obj);
            const named = entries.filter(([name, value]) => (typeof name === 'string' && (name.indexOf('.') !== -1 || name.indexOf('|') !== -1)) && pytorch.Utility.isTensor(value));
            if (named.length > 0 && (named.length / entries.length) >= 0.8) {
                obj = new Map(entries);
            }
        }
        if (obj instanceof Map) {
            const entries = Array.from(obj).filter(([name]) => name !== '_metadata');
            const names = entries.filter(([name]) => typeof name === 'string' && (name.indexOf('.') !== -1 || name.indexOf('|') !== -1));
            if (names.length > 1 && (names.length / entries.length) >= 0.8 &&
                (entries.every(([, value]) => !pytorch.Utility.isInstance(value, 'builtins.dict') || Array.from(value.values()).every((value) => !pytorch.Utility.isTensor(value)))) &&
                (!entries.every(([, value]) => Array.isArray(value)))) {
                const modules = new Map();
                for (const [name, value] of entries) {
                    // '|' is used as the separator only when no '.' is present.
                    const separator = name.indexOf('.') === -1 && name.indexOf('|') !== -1 ? '|' : '.';
                    const path = name.split(separator);
                    let property = path.pop();
                    if (path.length > 1 && path[path.length - 1] === '_packed_params') {
                        // Keep '_packed_params' with the property, not the module path.
                        property = `${path.pop()}.${property}`;
                    }
                    const key = path.join(separator);
                    if (!modules.has(key)) {
                        modules.set(key, {});
                    }
                    const module = modules.get(key);
                    if (pytorch.Utility.isTensor(value)) {
                        value.__name__ = name;
                    }
                    module[property] = value;
                }
                return modules;
            }
        }
        // Fallback: obj maps module names directly to flat sub-dicts; every
        // sub-dict must contribute at least one tensor or the whole object is
        // rejected.
        if (obj && !Array.isArray(obj) && Object(obj) === obj) {
            const modules = new Map();
            const entries = obj instanceof Map ? Array.from(obj) : Object.entries(obj);
            if (entries.length > 0 && entries) {
                for (const [key, value] of entries) {
                    const name = key.toString();
                    if (!value || Object(value) !== value || pytorch.Utility.isTensor(value) || ArrayBuffer.isView(value) || value._modules instanceof Map) {
                        return null;
                    }
                    if (!modules.has(name)) {
                        modules.set(name, {});
                    }
                    const module = modules.get(name);
                    let tensor = false;
                    const entries = value instanceof Map ? value : new Map(Object.entries(value));
                    for (const [name, value] of entries) {
                        if (typeof name !== 'string') {
                            return null;
                        }
                        if (name.indexOf('.') !== -1) {
                            return null;
                        }
                        if (name === '_metadata') {
                            continue;
                        }
                        if (typeof value === 'string' || typeof value === 'number') {
                            module[name] = value;
                            continue;
                        }
                        if (pytorch.Utility.isTensor(value)) {
                            value.__name__ = name;
                            module[name] = value;
                            tensor = true;
                        }
                    }
                    if (!tensor) {
                        return null;
                    }
                }
                return modules;
            }
        }
        return null;
    }
    // Heuristic for the '_metadata' OrderedDict attached to state dicts.
    static isMetadataObject(obj) {
        if (pytorch.Utility.isInstance(obj, 'collections.OrderedDict')) {
            for (const value of obj.values()) {
                if (pytorch.Utility.isInstance(value, 'builtins.dict')) {
                    const entries = Array.from(value);
                    // NOTE(review): entries[0] is a [key, value] pair, so the
                    // comparisons against 'version' and 1 look like they were
                    // meant to be entries[0][0] / entries[0][1] — kept as-is to
                    // preserve existing acceptance behavior; verify upstream.
                    if (entries.length !== 1 && entries[0] !== 'version' && entries[1] !== 1) {
                        return false;
                    }
                }
            }
            return true;
        }
        return false;
    }
};
// Parser for the binary format produced by torch.backends._nnapi.serializer:
// a version header, counts, operand/value/operation tables, then the
// per-value payloads and per-operation/input/output operand index lists.
nnapi.SerializedModel = class {
    constructor(serialized_model, buffers) {
        const reader = base.BinaryReader.open(serialized_model);
        this.version = reader.int32();
        if (this.version !== 1) {
            throw new pytorch.Error('Invalid NNAPI serialized model version.');
        }
        const operands = new Array(reader.int32());
        const values = new Array(reader.int32());
        this.operations = new Array(reader.int32());
        this.inputs = new Array(reader.int32());
        this.outputs = new Array(reader.int32());
        // NNAPI OperandCode values -> readable type names ('[]' = tensor type).
        const data_types = new Map([
            [0, 'float32'],
            [1, 'int32'],
            [2, 'uint32'],
            [3, 'float32[]'],
            [4, 'int32[]'],
            [5, 'quant8_asymm[]'],
            [6, 'boolean'],
            [7, 'quant16_symm[]'],
            [8, 'float16[]'],
            [9, 'boolean[]'],
            [10, 'float16'],
            [11, 'quant8_symm_per_channel[]'],
            [12, 'quant16_asymm[]'],
            [13, 'quant8_symm[]'],
            [14, 'quant8_asymm_signed[]'],
            [16, 'model']
        ]);
        // Operand headers; dimension values are filled in later.
        for (let i = 0; i < operands.length; i++) {
            const data_type = reader.int32();
            operands[i] = {
                index: i,
                data_type: data_types.has(data_type) ? data_types.get(data_type) : data_type,
                dimensions: new Array(reader.uint32()),
                scale: reader.float32(),
                zero_point: reader.int32()
            };
        }
        // Value headers: which operand, where its payload comes from, and how
        // long that payload is.
        for (let i = 0; i < values.length; i++) {
            values[i] = {
                index: reader.int32(),
                source_type: reader.int32(),
                source_length: reader.uint32()
            };
        }
        // Operation headers; input/output operand lists are filled in later.
        for (let i = 0; i < this.operations.length; i++) {
            this.operations[i] = {
                index: reader.int32(),
                identifier: i,
                inputs: new Array(reader.uint32()),
                outputs: new Array(reader.uint32())
            };
        }
        for (const operand of operands) {
            for (let i = 0; i < operand.dimensions.length; i++) {
                operand.dimensions[i] = reader.uint32();
            }
        }
        // Value payloads, in declaration order.
        for (const value of values) {
            const index = value.index;
            const operand = operands[index];
            switch (value.source_type) {
                case 0: { // immediate
                    switch (operand.data_type) {
                        case 'boolean':
                            operand.value = reader.byte() ? true : false;
                            reader.skip(3); // padded to 4 bytes
                            break;
                        case 'int32':
                            operand.value = reader.int32();
                            break;
                        case 'float32':
                            operand.value = reader.float32();
                            break;
                        case 'int32[]':
                            operand.data = reader.read(value.source_length);
                            break;
                        case 'float32[]':
                            operand.data = reader.read(value.source_length);
                            break;
                        default:
                            throw new pytorch.Error(`Unsupported NNAPI operand type '${operand.data_type}'.`);
                    }
                    break;
                }
                case 2: { // numbered buffer
                    if (value.source_length !== 12) {
                        throw new pytorch.Error('Invalid NNAPI numbered buffer source length.');
                    }
                    const number = reader.uint32();
                    const offset = reader.uint32();
                    const operand_length = reader.uint32();
                    if (number < buffers.length && buffers[number].data) {
                        const storage = buffers[number];
                        const data = storage.data && storage.data.peek ? storage.data.peek() : storage.data;
                        // Fixed: slice takes an end index, not a length — the
                        // old 'slice(offset, operand_length)' produced an empty
                        // (or truncated) view whenever offset >= operand_length.
                        operand.data = data.slice(offset, offset + operand_length);
                    }
                    break;
                }
                case 3: { // numbered memory
                    throw new pytorch.Error('NNAPI numbered memory buffer not implemented.');
                }
                default: {
                    throw new pytorch.Error('Unsupported NNAPI value source type.');
                }
            }
        }
        // Resolve operation and model input/output operand indices.
        for (const operation of this.operations) {
            for (let i = 0; i < operation.inputs.length; i++) {
                const index = reader.uint32();
                operation.inputs[i] = operands[index];
            }
            for (let i = 0; i < operation.outputs.length; i++) {
                const index = reader.uint32();
                operation.outputs[i] = operands[index];
            }
        }
        for (let i = 0; i < this.inputs.length; i++) {
            const index = reader.uint32();
            this.inputs[i] = operands[index];
        }
        for (let i = 0; i < this.outputs.length; i++) {
            const index = reader.uint32();
            this.outputs[i] = operands[index];
        }
        // The format has no trailer; anything left over means a parse error.
        if (reader.position !== reader.length) {
            throw new pytorch.Error('Invalid NNAPI serialized model length.');
        }
    }
};
// Graph view over a parsed nnapi.SerializedModel: converts operands to
// pytorch.Value objects (with quantization info and initializers) and
// operations to nnapi.Node instances.
nnapi.Graph = class {
    constructor(model) {
        this.name = 'torch.classes._nnapi.Compilation';
        this.nodes = [];
        this.inputs = [];
        this.outputs = [];
        const values = new Map();
        // Lazily creates (and caches) the pytorch.Value for an operand.
        values.map = (operand) => {
            if (!values.has(operand.index)) {
                const name = operand.index.toString();
                const dimensions = operand.dimensions;
                const shape = new pytorch.TensorShape(dimensions);
                let dataType = operand.data_type.replace('[]', '');
                let quantization = null;
                switch (dataType) {
                    case 'quant8_asymm':
                    case 'quant8_symm_per_channel':
                    case 'quant8_symm':
                    // Fixed: '[]' is stripped from dataType above, so the old
                    // case label 'quant8_asymm_signed[]' could never match.
                    case 'quant8_asymm_signed':
                    case 'quant16_asymm':
                    case 'quant16_symm':
                        quantization = dataType;
                        dataType = dataType.indexOf('16') === -1 ? 'uint8' : 'uint16';
                        break;
                    default:
                        break;
                }
                const type = new pytorch.TensorType(dataType, shape);
                let initializer = null;
                if (operand.data) {
                    // Wrap the raw payload in the minimal tensor protocol that
                    // pytorch.Tensor expects.
                    const size = dimensions.reduce((a, b) => a * b, 1);
                    const tensor = {
                        dtype: { __reduce__: () => dataType },
                        size: () => dimensions,
                        stride: () => null,
                        storage_offset: () => 0,
                        storage: () => ({
                            dtype: { __reduce__: () => type.dataType },
                            data: operand.data, size: () => size
                        })
                    };
                    const context = new pytorch.Context();
                    initializer = new pytorch.Tensor(context, null, tensor);
                }
                // Non-trivial scale/zero_point also implies linear quantization.
                if (quantization || (operand.scale !== undefined && operand.scale !== 0) || (operand.zero_point !== undefined && operand.zero_point !== 0)) {
                    quantization = {
                        type: quantization || 'linear',
                        scale: [operand.scale],
                        offset: [operand.zero_point]
                    };
                }
                const value = new pytorch.Value(name, type, quantization, initializer);
                values.set(operand.index, value);
            }
            return values.get(operand.index);
        };
        const metadata = new nnapi.Metadata();
        for (const operation of model.operations) {
            const node = new nnapi.Node(metadata, operation, values);
            this.nodes.push(node);
        }
        for (let i = 0; i < model.inputs.length; i++) {
            const name = i.toString();
            const operand = model.inputs[i];
            const argument = new pytorch.Argument(name, [values.map(operand)]);
            this.inputs.push(argument);
        }
        for (let i = 0; i < model.outputs.length; i++) {
            const name = i.toString();
            const operand = model.outputs[i];
            const argument = new pytorch.Argument(name, [values.map(operand)]);
            this.outputs.push(argument);
        }
    }
};
// Node view of a single NNAPI operation: maps operand inputs/outputs to
// pytorch.Argument objects and folds fused activations into the chain.
nnapi.Node = class {
    constructor(metadata, operation, values) {
        // Input data types select among overloaded schema variants.
        const signature = (operation.inputs || []).map((input) => input.data_type);
        this.name = '';
        this.type = metadata.type(operation.index, signature);
        this.inputs = [];
        this.outputs = [];
        this.attributes = [];
        this.chain = [];
        if (operation.identifier !== undefined) {
            this.identifier = operation.identifier.toString();
        }
        if (Array.isArray(operation.inputs)) {
            const inputs = this.type.inputs;
            for (let i = 0; i < operation.inputs.length; i++) {
                // Use the schema's input name when available, else the index.
                const name = i < inputs.length ? inputs[i].name : i.toString();
                const operand = operation.inputs[i];
                if (operand.dimensions.length > 0) {
                    // Tensor operand: becomes a regular input value.
                    const value = values.map(operand);
                    const argument = new pytorch.Argument(name, [value]);
                    this.inputs.push(argument);
                } else if (name === 'activation') {
                    // Fused activation code -> chained node. 1/2/3 map to the
                    // NNAPI RELU/RELU1/RELU6 operation indices 19/20/21; 0 (or
                    // anything unknown) means no fused activation.
                    const activation = new Map([[1, 19], [2, 20], [3, 21]]).get(operand.value) || 0;
                    if (activation !== 0) {
                        this.chain.push(new nnapi.Node(metadata, { index: activation }));
                    }
                } else {
                    // Scalar operand: represented as an attribute-style argument.
                    const attribute = new pytorch.Argument(name, operand.value, operand.data_type, false);
                    this.inputs.push(attribute);
                }
            }
        }
        if (Array.isArray(operation.outputs)) {
            const outputs = this.type.outputs;
            for (let i = 0; i < operation.outputs.length; i++) {
                const name = i < outputs.length ? outputs[i].name : i.toString();
                const operand = operation.outputs[i];
                const value = values.map(operand);
                const argument = new pytorch.Argument(name, [value]);
                this.outputs.push(argument);
            }
        }
    }
};
nnapi.Metadata = class {

    // Registry of NNAPI operation types keyed by numeric operation code.
    // An operation code may have multiple registrations (overloads) that
    // differ in their attribute layout; type() picks one by signature.
    constructor() {
        this._types = new Map();
        // https://developer.android.com/ndk/reference/group/neural-networks
        // https://github.com/pytorch/pytorch/commits/master/torch/backends/_nnapi/serializer.py
        this.register(0, 'ADD', '', ['A', 'B'], [['activation', 'int32']], ['C']);
        this.register(1, 'AVERAGE_POOL_2D', 'Pool', ['input'], [['padding_left', 'int32'], ['padding_right', 'int32'], ['padding_top', 'int32'], ['padding_bottom', 'int32'], ['stride_x', 'int32'], ['stride_y', 'int32'], ['filter_x', 'int32'], ['filter_y', 'int32'], ['activation', 'int32'], ['nchw', 'boolean']], ['output']);
        this.register(1, 'AVERAGE_POOL_2D', 'Pool', ['input'], [['padding_scheme', 'int32'], ['stride_x', 'int32'], ['stride_y', 'int32'], ['filter_x', 'int32'], ['filter_y', 'int32'], ['activation', 'int32'], ['nchw', 'boolean']], ['output']);
        this.register(2, 'CONCATENATION');
        this.register(3, 'CONV_2D', 'Layer', ['input', 'weights', 'bias'], [['padding_left', 'int32'], ['padding_right', 'int32'], ['padding_top', 'int32'], ['padding_bottom', 'int32'], ['stride_x', 'int32'], ['stride_y', 'int32'], ['activation', 'int32'], ['nchw', 'boolean'], ['dilation_width', 'int32'], ['dilation_height', 'int32']], ['output']);
        this.register(3, 'CONV_2D', 'Layer', ['input', 'weights', 'bias'], [['padding_scheme', 'int32'], ['stride_x', 'int32'], ['stride_y', 'int32'], ['activation', 'int32'], ['nchw', 'boolean'], ['dilation_width', 'int32'], ['dilation_height', 'int32']], ['output']);
        this.register(4, 'DEPTHWISE_CONV_2D', 'Layer', ['input', 'weights', 'bias'], [['padding_left', 'int32'], ['padding_right', 'int32'], ['padding_top', 'int32'], ['padding_bottom', 'int32'], ['stride_x', 'int32'], ['stride_y', 'int32'], ['activation', 'int32'], ['nchw', 'boolean'], ['dilation_width', 'int32'], ['dilation_height', 'int32']], ['output']);
        this.register(4, 'DEPTHWISE_CONV_2D', 'Layer', ['input', 'weights', 'bias'], [['padding_scheme', 'int32'], ['stride_x', 'int32'], ['stride_y', 'int32'], ['activation', 'int32'], ['nchw', 'boolean'], ['dilation_width', 'int32'], ['dilation_height', 'int32']], ['output']);
        this.register(5, 'DEPTH_TO_SPACE');
        this.register(6, 'DEQUANTIZE');
        this.register(7, 'EMBEDDING_LOOKUP');
        this.register(8, 'FLOOR');
        this.register(9, 'FULLY_CONNECTED', 'Layer', ['input', 'weights', 'bias'], [['activation', 'int32']], ['output']);
        this.register(10, 'HASHTABLE_LOOKUP');
        this.register(11, 'L2_NORMALIZATION');
        this.register(12, 'L2_POOL_2D', 'Pool');
        this.register(13, 'LOCAL_RESPONSE_NORMALIZATION');
        this.register(14, 'LOGISTIC');
        this.register(15, 'LSH_PROJECTION');
        this.register(16, 'LSTM', 'Layer');
        this.register(17, 'MAX_POOL_2D', 'Pool');
        this.register(18, 'MUL');
        this.register(19, 'RELU', 'Activation', ['input'], [], ['output']);
        this.register(20, 'RELU1', 'Activation');
        this.register(21, 'RELU6', 'Activation');
        this.register(22, 'RESHAPE', 'Shape', ['input', 'shape'], [], ['output']);
        this.register(23, 'RESIZE_BILINEAR');
        this.register(24, 'RNN', 'Layer');
        this.register(25, 'SOFTMAX', 'Activation');
        this.register(26, 'SPACE_TO_DEPTH');
        this.register(27, 'SVDF');
        this.register(28, 'TANH');
        this.register(29, 'BATCH_TO_SPACE_ND');
        this.register(30, 'DIV');
        this.register(31, 'MEAN');
        this.register(32, 'PAD');
        this.register(33, 'SPACE_TO_BATCH_ND');
        this.register(34, 'SQUEEZE');
        this.register(35, 'STRIDED_SLICE');
        this.register(36, 'SUB');
        this.register(37, 'TRANSPOSE');
        this.register(38, 'ABS');
        this.register(39, 'ARGMAX');
        this.register(40, 'ARGMIN');
        this.register(41, 'AXIS_ALIGNED_BBOX_TRANSFORM');
        this.register(42, 'BIDIRECTIONAL_SEQUENCE_LSTM');
        this.register(43, 'BIDIRECTIONAL_SEQUENCE_RNN');
        this.register(44, 'BOX_WITH_NMS_LIMIT');
        this.register(45, 'CAST');
        this.register(46, 'CHANNEL_SHUFFLE');
        this.register(47, 'DETECTION_POSTPROCESSING');
        this.register(48, 'EQUAL');
        this.register(49, 'EXP');
        this.register(50, 'EXPAND_DIMS');
        this.register(51, 'GATHER');
        this.register(52, 'GENERATE_PROPOSALS');
        this.register(53, 'GREATER');
        this.register(54, 'GREATER_EQUAL');
        this.register(55, 'GROUPED_CONV_2D');
        this.register(56, 'HEATMAP_MAX_KEYPOINT');
        this.register(57, 'INSTANCE_NORMALIZATION');
        this.register(58, 'LESS');
        this.register(59, 'LESS_EQUAL');
        this.register(60, 'LOG');
        this.register(61, 'LOGICAL_AND');
        this.register(62, 'LOGICAL_NOT');
        this.register(63, 'LOGICAL_OR');
        this.register(64, 'LOG_SOFTMAX');
        this.register(65, 'MAXIMUM');
        this.register(66, 'MINIMUM');
        this.register(67, 'NEG');
        this.register(68, 'NOT_EQUAL');
        this.register(69, 'PAD_V2');
        this.register(70, 'POW');
        this.register(71, 'PRELU');
        this.register(72, 'QUANTIZE');
        this.register(73, 'QUANTIZED_16BIT_LSTM');
        this.register(74, 'RANDOM_MULTINOMIAL');
        this.register(75, 'REDUCE_ALL');
        this.register(76, 'REDUCE_ANY');
        this.register(77, 'REDUCE_MAX');
        this.register(78, 'REDUCE_MIN');
        this.register(79, 'REDUCE_PROD');
        this.register(80, 'REDUCE_SUM');
        this.register(81, 'ROI_ALIGN');
        this.register(82, 'ROI_POOLING');
        this.register(83, 'RSQRT');
        this.register(84, 'SELECT');
        this.register(85, 'SIN');
        this.register(86, 'SLICE');
        this.register(87, 'SPLIT');
        this.register(88, 'SQRT');
        this.register(89, 'TILE');
        this.register(90, 'TOPK_V2');
        this.register(91, 'TRANSPOSE_CONV_2D', 'Layer');
        this.register(92, 'UNIDIRECTIONAL_SEQUENCE_LSTM', 'Layer');
        this.register(93, 'UNIDIRECTIONAL_SEQUENCE_RNN', 'Layer');
        this.register(94, 'RESIZE_NEAREST_NEIGHBOR');
        this.register(95, 'QUANTIZED_LSTM', 'Layer');
        this.register(96, 'IF');
        this.register(97, 'WHILE');
        this.register(98, 'ELU', 'Activation');
        this.register(99, 'HARD_SWISH', 'Activation');
        this.register(100, 'FILL');
        this.register(101, 'RANK');
    }

    // Adds one overload for the given operation code. Tensor input names and
    // attribute [name, type] pairs are concatenated into a single positional
    // 'inputs' list to mirror the NNAPI operand layout.
    register(index, name, category, inputs, attributes, outputs) {
        inputs = inputs || [];
        outputs = outputs || [];
        attributes = attributes || [];
        const type = {};
        type.name = name;
        type.inputs = inputs.map((input) => ({ name: input, type: 'Tensor' }));
        type.inputs = type.inputs.concat(attributes.map(([attributeName, attributeType]) => ({ name: attributeName, type: attributeType })));
        type.outputs = outputs.map((output) => ({ name: output, type: 'Tensor' }));
        if (category) {
            type.category = category;
        }
        if (!this._types.has(index)) {
            this._types.set(index, []);
        }
        this._types.get(index).push(type);
    }

    // Returns the overload of operation code 'index' that best matches the
    // operand data-type signature, falling back to the first registration.
    type(index, signature) {
        if (!this._types.has(index)) {
            // Unknown operation code: synthesize a placeholder entry. It must
            // be wrapped in an array to match the shape register() produces --
            // storing a bare object here made the for...of below throw
            // 'types is not iterable' and types[0] return undefined.
            this._types.set(index, [{ name: index.toString(), inputs: [], outputs: [], attributes: [] }]);
        }
        const types = this._types.get(index);
        for (const type of types) {
            const inputs = type.inputs;
            if (signature.length < inputs.length) {
                if (inputs.every((input, i) => input.type === undefined || input.type === 'Tensor' || input.type === signature[i])) {
                    return type;
                }
            }
        }
        return types[0];
    }
};
pytorch.Metadata = class {

    // Loads 'pytorch-metadata.json' once and caches the resulting singleton.
    static async open(context) {
        if (!pytorch.Metadata._metadata) {
            let data = null;
            try {
                data = await context.request('pytorch-metadata.json');
            } catch {
                // The metadata file is optional; continue with an empty registry.
            }
            pytorch.Metadata._metadata = new pytorch.Metadata(data);
        }
        return pytorch.Metadata._metadata;
    }

    constructor(data) {
        this._types = new Map();
        this._attributes = new Map();
        this._index = new Map();
        if (data) {
            for (const item of JSON.parse(data)) {
                // Index each schema by its name with any '(...)' argument list stripped.
                const paren = item.name.indexOf('(');
                const key = paren === -1 ? item.name : item.name.substring(0, paren);
                this._types.set(key, item);
            }
        }
    }

    add(name, value) {
        this._types.set(name, value);
    }

    type(name) {
        return this._types.get(name);
    }

    // Looks up the metadata entry for attribute 'name' of operator 'type'.
    // On the first miss for a type, negative-caches the query and populates
    // every input/attribute of that type in one pass.
    attribute(type, name) {
        const key = `${type}:${name}`;
        if (!this._attributes.has(key)) {
            this._attributes.set(key, null);
            const metadata = this.type(type);
            if (metadata) {
                for (const input of metadata.inputs || []) {
                    this._attributes.set(`${type}:${input.name}`, input);
                }
                for (const attribute of metadata.attributes || []) {
                    this._attributes.set(`${type}:${attribute.name}`, attribute);
                }
            }
        }
        return this._attributes.get(key);
    }

    // Registers every namespaced ('ns::op') schema with the torch operator
    // registry of the given execution, and exposes one op namespace per
    // distinct schema namespace.
    register(execution) {
        const torch = execution.register('torch');
        const registry = torch._C.getRegistry();
        const modules = new Set();
        for (const [name, type] of this._types) {
            if (name.indexOf('::') !== -1) {
                const schema = torch.FunctionSchema.parse(type.name);
                if (type.category) {
                    schema.category = type.category;
                }
                schema.setAliasAnalysis('FROM_SCHEMA');
                registry.registerOperator(new torch._C.Operator(schema));
                modules.add(type.name.split('::')[0]);
            }
        }
        for (const module of modules) {
            execution.register(`torch.ops.${module}`, new torch._ops._OpNamespace(module));
        }
    }
};
numpy.Tensor = class {

    // Adapts an ndarray-like object to the viewer's tensor interface.
    constructor(array) {
        const dataType = array.dtype.__name__;
        this.type = new numpy.TensorType(dataType, new numpy.TensorShape(array.shape));
        // Convert byte strides into element strides.
        this.stride = array.strides.map((stride) => stride / array.itemsize);
        const objectLike = this.type.dataType === 'string' || this.type.dataType === 'object';
        this.encoding = objectLike ? '|' : array.dtype.byteorder;
        // Object-like and void dtypes are kept as nested JS values; everything
        // else is serialized to raw bytes.
        this.values = objectLike || this.type.dataType === 'void' ? array.flatten().tolist() : array.tobytes();
    }
};
numpy.TensorType = class {

    // Element data type plus shape; '?' stands in for an unknown data type.
    constructor(dataType, shape) {
        this.dataType = dataType || '?';
        this.shape = shape;
    }

    toString() {
        return `${this.dataType}${this.shape.toString()}`;
    }
};
numpy.TensorShape = class {

    // Dimension list of a tensor; may be empty or absent for scalars.
    constructor(dimensions) {
        this.dimensions = dimensions;
    }

    toString() {
        if (this.dimensions && this.dimensions.length > 0) {
            return `[${this.dimensions.join(',')}]`;
        }
        return '';
    }
};
// Error subclass used for failures while loading a PyTorch model; the 'name'
// field is surfaced to the user as the error title.
pytorch.Error = class extends Error {
    constructor(message) {
        super(message);
        this.name = 'Error loading PyTorch model.';
    }
};
- export const Metadata = pytorch.Metadata;
- export const ModelFactory = pytorch.ModelFactory;
|