  1. // Experimental
  2. import * as base from './base.js';
  3. import * as flatbuffers from './flatbuffers.js';
  4. import * as python from './python.js';
// Namespace objects. Classes are attached as properties below
// (nnapi/numpy are presumably populated later in this file — outside the
// visible region; verify before removing).
const pytorch = {};
const nnapi = {};
const numpy = {};
  8. pytorch.ModelFactory = class {
  9. async match(context) {
  10. const reader = await pytorch.Reader.open(context);
  11. if (reader) {
  12. return context.set(reader.type, reader);
  13. }
  14. return null;
  15. }
  16. filter(context, match) {
  17. if (context.type === 'pytorch.export' && match.type === 'pytorch.zip') {
  18. return false;
  19. }
  20. if (context.type === 'pytorch.index' && match.type === 'pytorch.zip') {
  21. return false;
  22. }
  23. if (context.type === 'pytorch.model.json' && match.type === 'pytorch.data.pkl') {
  24. return false;
  25. }
  26. if (context.type === 'pytorch.model.json' && match.type === 'pickle') {
  27. return false;
  28. }
  29. return true;
  30. }
  31. async open(context) {
  32. const metadata = await pytorch.Metadata.open(context);
  33. const target = context.value;
  34. target.on('resolve', (sender, name) => {
  35. context.error(new pytorch.Error(`Unknown type name '${name}'.`), false);
  36. });
  37. await target.read(metadata);
  38. if (!target.format || (!target.modules && !target.module)) {
  39. throw new pytorch.Error("Reader not implemented.");
  40. }
  41. return new pytorch.Model(metadata, target);
  42. }
  43. };
  44. pytorch.Model = class {
  45. constructor(metadata, target) {
  46. this.format = target.format;
  47. this.producer = target.producer || '';
  48. this.modules = [];
  49. if (target.module) {
  50. const graph = new pytorch.Graph(target.execution, metadata, null, '', target.module);
  51. this.modules.push(graph);
  52. delete target.execution;
  53. } else if (target.modules) {
  54. for (const [name, value] of target.modules) {
  55. const graph = new pytorch.Graph(target.execution, metadata, null, name, value);
  56. this.modules.push(graph);
  57. delete target.execution;
  58. }
  59. }
  60. }
  61. };
pytorch.Graph = class {
    // Builds a renderable graph (nodes/inputs/outputs) from one of several
    // module representations: a TorchScript RecursiveScriptModule, a
    // torch.export ExportedProgram, a torch.fx GraphModule, a bare tensor,
    // or a plain weights container.
    constructor(execution, metadata, type, name = '', module = null) {
        this.nodes = [];
        this.inputs = [];
        this.outputs = [];
        this.name = name;
        this.type = type;
        const context = new pytorch.Context(execution, metadata);
        // Memoizing value factory: a tensor-backed value is always fresh;
        // otherwise values are deduplicated by name, and re-registering an
        // existing name with a type or tensor is an error.
        context.values.map = (name, type, tensor) => {
            if (tensor) {
                return new pytorch.Value(name, type, null, tensor);
            }
            if (!context.values.has(name)) {
                context.values.set(name, new pytorch.Value(name, type, null, tensor));
            } else if (type || tensor) {
                throw new pytorch.Error(`Duplicate value '${name}'.`);
            }
            return context.values.get(name);
        };
        const torch = execution ? execution.torch : null;
        if (torch && module instanceof torch.jit._script.RecursiveScriptModule && module._c._has_method('forward')) {
            // --- TorchScript path: simplify the JIT graph in place, then emit nodes.
            const initializers = new Map();
            const graph = module.graph;
            // code_with_constants returns a pair; index 1 carries the constant table.
            const constants = module.code_with_constants[1].const_mapping;
            if (constants) {
                for (const [key, value] of constants) {
                    const name = `CONSTANTS.${key}`;
                    if (pytorch.Utility.isTensor(value)) {
                        initializers.set(value, new pytorch.Tensor(context, name, value));
                    } else if (pytorch.Utility.isObject(value)) {
                        initializers.set(value, value);
                    } else {
                        // Unsupported constants are silently skipped.
                        // throw new pytorch.Error('Unsupported constant.');
                    }
                }
            }
            // NOTE(review): nothing is ever added to `deleted` (the `add` calls
            // below are commented out), so the `deleted.has(node)` check later
            // never fires; destroyed nodes are simply absent from graph.nodes().
            const deleted = new Set();
            const param_node = graph.param_node();
            // `self` is the first graph input when it refers to the module itself.
            const self = param_node && param_node.outputs().length > 0 && param_node.outputs()[0].type() === module._c._type() ? param_node.outputs()[0] : null;
            if (self) {
                // Resolve prim::GetAttr chains to concrete Python objects and
                // dotted identifiers, recursing toward the module root.
                const getattr = (value) => {
                    if (value.value === undefined) {
                        const node = value.node();
                        if (node.kind() === 'prim::GetAttr') {
                            const [input] = node.inputs();
                            getattr(input);
                            if (input.value !== undefined) {
                                const name = node.s('name');
                                value.value = input.value.__getattr__(name);
                                value.identifier = input.identifier ? `${input.identifier}.${name}` : name;
                            }
                        }
                        if (node === param_node && value === param_node.outputs()[0]) {
                            value.value = module;
                            value.identifier = '';
                        }
                    }
                };
                for (const node of graph.nodes()) {
                    for (const input of node.inputs()) {
                        getattr(input, node); // second argument is ignored by getattr
                    }
                }
                // Remove the now-resolved prim::GetAttr nodes, depth-first.
                const delattr = (value) => {
                    // Copy uses() before destroying nodes that mutate the use list.
                    for (const use of Array.from(value.uses())) {
                        const node = use.user;
                        if (node.kind() === 'prim::GetAttr') {
                            for (const output of node.outputs()) {
                                delattr(output);
                            }
                            // deleted.add(node);
                            node.destroy();
                        }
                    }
                };
                delattr(param_node.outputs()[0], ''); // second argument is ignored by delattr
            }
            // Fold prim::Constant nodes into their output values.
            for (const node of graph.nodes()) {
                if (node.kind() === 'prim::Constant' && node.outputs().length === 1) {
                    const output = node.output();
                    output.identifier = output.debugName();
                    if (node.hasAttribute('value')) {
                        const kind = node.kindOf('value');
                        output.value = node[kind]('value');
                    } else if (node.output().type() instanceof torch.NoneType) {
                        output.value = null;
                    }
                    // deleted.add(node);
                    node.destroy();
                }
            }
            // Fold prim::TupleUnpack when the input is a known array of primitives.
            for (const node of graph.nodes()) {
                if (node.kind() === 'prim::TupleUnpack') {
                    const value = node.inputs()[0].value;
                    if (Array.isArray(value) && value.length === node.outputs().length && value.every((value) => typeof value === 'number' || typeof value === 'string' || typeof value === 'boolean')) {
                        for (let i = 0; i < node.outputs().length; i++) {
                            const output = node.outputs()[i];
                            output.value = value[i];
                        }
                        // deleted.add(node);
                        node.destroy();
                    }
                }
            }
            // Fold prim::ListConstruct over known primitives into a plain array,
            // except when consumed by prim::CallMethod.
            for (const node of graph.nodes()) {
                if (node.kind() === 'prim::ListConstruct' &&
                    node.inputs().every((value) => typeof value.value === 'number' || typeof value.value === 'string' || typeof value.value === 'boolean') &&
                    node.outputs().every((value) => value.uses().every((use) => use.user.kind() !== 'prim::CallMethod'))) {
                    node.outputs()[0].value = node.inputs().map((value) => value.value);
                    // deleted.add(node);
                    node.destroy();
                }
            }
            // Graph inputs: skip an unused `self`, keep everything else.
            for (const v of graph.inputs()) {
                if (self.uses().length === 0 && v === self) {
                    continue;
                }
                const identifier = pytorch.Utility.unique(v);
                const name = v.debugName() || identifier;
                const value = context.values.map(identifier);
                this.inputs.push(new pytorch.Argument(name, [value]));
            }
            for (const value of graph.outputs()) {
                const identifier = pytorch.Utility.unique(value);
                this.outputs.push(new pytorch.Argument(identifier, [context.values.map(identifier)]));
            }
            // Emit nodes, skipping param/return nodes and single-use tensor
            // list constructors (these are rendered inline by pytorch.Node).
            for (const node of graph.nodes()) {
                if (deleted.has(node)) {
                    continue;
                }
                if (node === graph.param_node() ||
                    node === graph.return_node()) {
                    continue;
                }
                if (node.kind() === 'prim::ListConstruct') {
                    if (node.outputs().length === 1 &&
                        node.outputs().every((output) => output.uses().length === 1) &&
                        node.inputs().every((input) => pytorch.Utility.isTensor(input.value) || input instanceof torch.Value)) {
                        continue;
                    }
                }
                this.nodes.push(new pytorch.Node(execution, metadata, null, null, node, initializers, context));
            }
        } else if (torch && module instanceof torch.export.exported_program.ExportedProgram && module.graph) {
            // --- torch.export path: bind lifted placeholders to their backing
            // tensors (parameters, buffers, lifted constants), then delegate to
            // context.graph for the FX body.
            const exported_program = module;
            const graph = exported_program.graph;
            const graph_module = exported_program.graph_module;
            const inputs_to_parameters = exported_program.graph_signature.inputs_to_parameters;
            const inputs_to_buffers = exported_program.graph_signature.inputs_to_buffers;
            const inputs_to_lifted_tensor_constants = exported_program.graph_signature.inputs_to_lifted_tensor_constants;
            const nodes = new Map(graph.nodes.map((node) => [node.name, node]));
            for (const obj of graph.nodes) {
                if (obj.op === 'placeholder') {
                    if (inputs_to_parameters.has(obj.name)) {
                        const key = inputs_to_parameters.get(obj.name);
                        const parameter = exported_program.state_dict.get(key);
                        // Fall back to the meta 'val' tensor when state_dict misses the key.
                        const tensor = parameter ? (parameter.data || parameter) : obj.meta.get('val');
                        const initializer = new pytorch.Tensor(context, key, tensor);
                        const value = new pytorch.Value(key, null, null, initializer);
                        context.values.set(obj, value);
                    } else if (inputs_to_buffers.has(obj.name)) {
                        const key = inputs_to_buffers.get(obj.name);
                        const buffer = exported_program.state_dict.get(key);
                        const tensor = buffer || obj.meta.get('val');
                        const initializer = new pytorch.Tensor(context, key, tensor);
                        const value = new pytorch.Value(key, null, null, initializer);
                        context.values.set(obj, value);
                    } else if (inputs_to_lifted_tensor_constants.has(obj.name)) {
                        const key = inputs_to_lifted_tensor_constants.get(obj.name);
                        const constant = exported_program.constants.get(key);
                        const tensor = constant && constant.data ? constant.data : obj.meta.get('val');
                        const initializer = new pytorch.Tensor(context, key, tensor);
                        const value = new pytorch.Value(key, null, null, initializer);
                        context.values.set(obj, value);
                    }
                    // Multi-consumer initializers become explicit nodes so the
                    // shared value is rendered once.
                    if (obj.users.size > 1 && context.values.has(obj)) {
                        const node = new pytorch.Node(execution, metadata, obj.name, null, obj, null, context);
                        this.nodes.push(node);
                        context.values.set(obj, node.outputs[0].value[0]);
                    }
                }
            }
            context.graph(this, graph_module, false);
            // User-facing inputs from the graph signature become graph inputs.
            for (const input_spec of exported_program.graph_signature.user_inputs) {
                if (nodes.has(input_spec)) {
                    const node = nodes.get(input_spec);
                    const value = context.value(node);
                    const argument = new pytorch.Argument(input_spec, [value]);
                    this.inputs.push(argument);
                }
            }
        } else if (torch && module instanceof torch.fx.GraphModule && module.graph) {
            // --- torch.fx path: fully delegated.
            context.graph(this, module, true);
        } else if (pytorch.Utility.isTensor(module)) {
            // --- Bare tensor: wrap it in a single node.
            const node = new pytorch.Node(execution, metadata, null, type, { value: module }, null, context);
            this.nodes.push(node);
        } else {
            // --- Fallback: treat the module as a weights container or a list
            // of plain eager modules.
            const weights = this.type === 'weights' ? module : pytorch.Utility.weights(module);
            if (weights) {
                this.name = !this.name && typeof module.__name__ === 'string' ? module.__name__ : this.name;
                for (const [name, module] of weights) {
                    const node = new pytorch.Node(execution, metadata, name, 'Weights', module, null, context);
                    this.nodes.push(node);
                }
            } else {
                const modules = Array.isArray(module) && module.every((module) => module && !pytorch.Utility.isTensor(module) && (module._modules !== undefined || module.__class__)) ? module : [module];
                for (const module of modules) {
                    const type = this.type === 'weights' ? 'Weights' : null;
                    const node = new pytorch.Node(execution, metadata, null, type, module, null, context);
                    this.nodes.push(node);
                }
            }
        }
    }
};
  277. pytorch.Argument = class {
  278. constructor(name, value, type = null, visible = true) {
  279. this.name = name;
  280. this.value = value;
  281. this.type = type;
  282. this.visible = visible;
  283. }
  284. };
  285. pytorch.Value = class Value {
  286. constructor(name, type, quantization, initializer = null) {
  287. if (typeof name !== 'string') {
  288. throw new pytorch.Error(`Invalid value identifier '${JSON.stringify(name)}'.`);
  289. }
  290. this.name = name;
  291. this.type = initializer && initializer.type ? initializer.type : type || null;
  292. this.quantization = quantization;
  293. this.initializer = initializer;
  294. }
  295. };
  296. pytorch.Node = class {
  297. constructor(execution, metadata, name, type, obj, initializers, context, stack) {
  298. const torch = execution ? execution.torch : null;
  299. const builtins = execution ? execution.builtins : null;
  300. this.name = name || '';
  301. this.nodes = [];
  302. this.attributes = [];
  303. this.inputs = [];
  304. this.outputs = [];
  305. this.blocks = [];
  306. this.metadata = [];
  307. if (torch && obj instanceof torch.Node) {
  308. const node = obj;
  309. const kind = node.kind();
  310. const schema = node.schema();
  311. const inputs = node.inputs();
  312. const outputs = node.outputs();
  313. this.type = {
  314. name: kind.indexOf('::') === -1 ? kind : kind.split('::').pop().split('.')[0],
  315. identifier: kind
  316. };
  317. if (schema && schema.category) {
  318. this.type.category = schema.category;
  319. }
  320. const getAttribute = (node, name) => {
  321. const kind = node.kindOf(name);
  322. let value = null;
  323. let type = null;
  324. switch (kind) {
  325. case 's': value = node.s(name); type = 'string'; break;
  326. case 'i': value = node.i(name); type = 'int64'; break;
  327. case 'f': value = node.f(name); type = 'float32'; break;
  328. case 't': value = node.t(name); type = 'tensor'; break;
  329. case 'ss': value = node.ss(name); type = 'string[]'; break;
  330. case 'tys': value = node.tys(name).map((ty) => pytorch.Utility.toType(ty)); type = 'type[]'; break;
  331. case 'ival': value = node.ival(name); break;
  332. default: throw new pytorch.Error(`Unsupported attribute kind '${kind}'.`);
  333. }
  334. return [type, value];
  335. };
  336. for (const name of node.attributeNames()) {
  337. const [type, value] = getAttribute(node, name);
  338. const attribute = new pytorch.Argument(name, value, type);
  339. this.attributes.push(attribute);
  340. }
  341. const mapTensor = (value) => {
  342. if (value.identifier && pytorch.Utility.isTensor(value.value)) {
  343. const identifier = value.identifier;
  344. if (!context.values.has(identifier)) {
  345. const tensor = new pytorch.Tensor(context, identifier, value.value);
  346. context.values.set(identifier, new pytorch.Value(identifier, null, null, tensor));
  347. }
  348. return context.values.map(identifier);
  349. }
  350. let initializer = null;
  351. let identifier = value.hasDebugName() ? `%${value.debugName().toString()}` : `%${value.unique().toString()}`;
  352. if (value.value) {
  353. const obj = value.value;
  354. const hide = obj.__parent__ ? obj.__parent__.__hide__ : true;
  355. initializer = hide ? initializers.get(obj) : null;
  356. identifier = initializer ? initializer.name : identifier;
  357. }
  358. if (initializer) {
  359. return new pytorch.Value(identifier, null, null, initializer);
  360. }
  361. return context.values.map(identifier);
  362. };
  363. for (let i = 0; i < inputs.length; i++) {
  364. const input = inputs[i];
  365. const arg = schema && schema.arguments && i < schema.arguments.length ? schema.arguments[i] : null;
  366. const name = arg && arg.name ? arg.name : i.toString();
  367. let type = arg ? arg.real_type : null;
  368. let array = false;
  369. if (type instanceof torch.ListType) {
  370. array = true;
  371. type = type.getElementType();
  372. }
  373. let argument = null;
  374. if (type && type instanceof torch.ClassType) {
  375. const obj = input.value;
  376. if (!array && initializers.has(obj)) {
  377. const node = new pytorch.Node(execution, metadata, name, type.qualified_name(), obj, initializers, context);
  378. argument = new pytorch.Argument(name, node, 'object');
  379. } else if (array && Array.isArray(obj) && obj.every((obj) => initializers.has(obj))) {
  380. const node = obj.map((obj) => new pytorch.Node(execution, metadata, name, type.qualified_name(), obj, initializers, context));
  381. argument = new pytorch.Argument(name, node, 'object[]');
  382. } else if (array && input.node().kind() === 'prim::ListConstruct' && input.uses().length === 1 && input.node().inputs().every((input) => input.value)) {
  383. const node = input.node().inputs().map((input) => new pytorch.Node(execution, metadata, name, null, input.value, initializers, context));
  384. argument = new pytorch.Argument(name, node, 'object[]');
  385. } else if (input.value === undefined) {
  386. const identifier = pytorch.Utility.unique(input);
  387. const value = context.values.map(identifier);
  388. argument = new pytorch.Argument(name, [value]);
  389. } else {
  390. const node = new pytorch.Node(execution, metadata, null, null, input.value, initializers, context);
  391. argument = new pytorch.Argument(name, node, 'object');
  392. }
  393. } else if ((input.type() instanceof torch.TensorType || (input.type() instanceof torch.OptionalType && input.type().getElementType() instanceof torch.TensorType)) && pytorch.Utility.isTensor(input.value)) {
  394. const value = mapTensor(input);
  395. argument = new pytorch.Argument(name, [value]);
  396. } else if (input instanceof torch.Value && !pytorch.Utility.isTensor(input.value)) {
  397. if (input.value !== undefined) {
  398. if (Array.isArray(input.value) && input.value.every((value) => pytorch.Utility.isTensor(value))) {
  399. continue;
  400. }
  401. const type = input.type() ? pytorch.Utility.toType(input.type()) : null;
  402. let value = input.value;
  403. if (value && value instanceof torch._C.IValue) {
  404. value = pytorch.Utility.toString(value);
  405. }
  406. if (value && value instanceof builtins.complex) {
  407. value = new base.Complex(value.real, value.imag);
  408. }
  409. argument = new pytorch.Argument(name, value, type || 'attribute');
  410. } else if (input.type() instanceof torch.ListType) {
  411. if (input.node() && input.node().kind() === 'prim::ListConstruct' && input.uses().length === 1 &&
  412. input.node().inputs().every((value) => value instanceof torch.Value || value.type() instanceof torch.IntType || value.type() instanceof torch.FloatType || value.type() instanceof torch.StringType || value.type() instanceof torch.ComplexType || value.type() instanceof torch.TensorType)) {
  413. const list = input.node().inputs();
  414. const args = list.map((value) => {
  415. if (pytorch.Utility.isTensor(value.value)) {
  416. return mapTensor(value);
  417. }
  418. if (value.value !== undefined) {
  419. return value.value;
  420. }
  421. const identifier = pytorch.Utility.unique(value);
  422. return context.values.map(identifier);
  423. });
  424. const type = list.every((value) => (pytorch.Utility.isTensor(value.value)) || value.value === null) ? null : pytorch.Utility.toType(input.type());
  425. argument = new pytorch.Argument(name, args, type);
  426. } else {
  427. const identifier = pytorch.Utility.unique(input);
  428. argument = new pytorch.Argument(name, [context.values.map(identifier)]);
  429. }
  430. } else if (input.type() instanceof torch.StringType && typeof input.value === 'string') {
  431. argument = new pytorch.Argument(name, input.value, 'string');
  432. } else if (input.type() instanceof torch.BoolType && (typeof input.value === 'boolean' || input.value === 0 || input.value === 1)) {
  433. argument = new pytorch.Argument(name, Boolean(input.value), 'boolean');
  434. } else if (input.type() instanceof torch.IntType && typeof input.value === 'number') {
  435. argument = new pytorch.Argument(name, input.value, 'int64');
  436. } else if (input.type() instanceof torch.FloatType && typeof input.value === 'number') {
  437. argument = new pytorch.Argument(name, input.value, 'float32');
  438. } else if (input.type() instanceof torch.NoneType && input.value === null) {
  439. argument = new pytorch.Argument(name, null, 'attribute');
  440. } else {
  441. const identifier = pytorch.Utility.unique(input);
  442. const value = context.values.map(identifier);
  443. argument = new pytorch.Argument(name, [value]);
  444. }
  445. } else if (pytorch.Utility.isTensor(input.value) || input.value === undefined || input.value === null) {
  446. let list = [input];
  447. if (input.node() && node !== input.node() &&
  448. input.node().kind() === 'prim::ListConstruct' &&
  449. input.uses().length === 1 &&
  450. input.node().inputs().every((input) => pytorch.Utility.isTensor(input.value))) {
  451. list = input.node().inputs();
  452. }
  453. const args = list.map((input) => {
  454. let initializer = null;
  455. let identifier = pytorch.Utility.unique(input);
  456. if (input.value) {
  457. const value = input.value;
  458. const hide = value.__parent__ ? value.__parent__.__hide__ : true;
  459. initializer = hide ? initializers.get(value) : null;
  460. identifier = initializer ? initializer.name : identifier;
  461. }
  462. if (initializer) {
  463. return new pytorch.Value(identifier, null, null, initializer);
  464. }
  465. return context.values.map(identifier);
  466. });
  467. argument = new pytorch.Argument(name, args);
  468. } else if (Array.isArray(input.value) && input.value.some((value) => value instanceof torch.Value)) {
  469. const args = input.value.map((value) => {
  470. if (value instanceof torch.Value) {
  471. const identifier = pytorch.Utility.unique(value);
  472. return context.values.map(identifier);
  473. }
  474. return value;
  475. });
  476. argument = new pytorch.Argument(name, args, pytorch.Utility.toType(type));
  477. } else {
  478. throw new pytorch.Error('Unsupported input value');
  479. }
  480. this.inputs.push(argument);
  481. }
  482. for (let i = 0; i < outputs.length; i++) {
  483. const output = outputs[i];
  484. const ret = schema && schema.returns && i < schema.returns.length ? schema.returns[i] : null;
  485. if (ret && ret.name) {
  486. name = ret.name;
  487. } else {
  488. name = i === 0 && outputs.length === 1 ? 'output' : `${i}`;
  489. }
  490. let list = [output];
  491. if (output.uses().length === 1 &&
  492. output.uses()[0].user &&
  493. output.uses()[0].user.kind() === 'prim::ListUnpack' &&
  494. output.uses()[0].user.outputs().every((output) => pytorch.Utility.isTensor(output.value))) {
  495. list = output.uses()[0].user.outputs();
  496. }
  497. const args = list.map((output) => context.values.map(pytorch.Utility.unique(output)));
  498. const argument = new pytorch.Argument(name, args);
  499. this.outputs.push(argument);
  500. }
  501. const blocks = node.blocks();
  502. for (let i = 0; i < blocks.length; i++) {
  503. const block = blocks[i];
  504. const nodes = Array.from(block.nodes());
  505. if (nodes.length > 0) {
  506. const name = `block${i.toString()}`;
  507. const graph = { name: '', nodes: [] }; // new pytorch.Graph(execution, metadata, null, name, blocks[i]);
  508. const argument = new pytorch.Argument(name, graph, 'graph');
  509. this.blocks.push(argument);
  510. }
  511. }
  512. const sourceRange = node.sourceRange();
  513. if (sourceRange) {
  514. this.metadata.push(new pytorch.Argument('source', sourceRange.toString().replace(/^at\s/, '').replace(/\.$/, ''), 'attribute'));
  515. if (sourceRange.source()) {
  516. const orig = sourceRange.source().findSourceRangeThatGenerated(sourceRange);
  517. if (orig) {
  518. this.metadata.push(new pytorch.Argument('generated', orig.toString(), 'attribute'));
  519. }
  520. }
  521. }
  522. } else if (torch && obj instanceof torch.fx.node.Node) {
  523. if (obj.op === 'call_function') {
  524. let name = null;
  525. const target = obj.target;
  526. if (target instanceof torch._ops.OpOverload) {
  527. name = target.name();
  528. } else if (target instanceof torch._ops.HigherOrderOperator) {
  529. name = `${target.namespace}::${target.name}`;
  530. } else if (builtins.isinstance(target, builtins.function)) {
  531. name = target.__name__;
  532. } else if (typeof target === 'string') {
  533. // Handle unresolved operators
  534. const match = target.match(/^torch\.ops\.([^.]+)\.(.+)$/);
  535. if (!match) {
  536. throw new pytorch.Error(`Unsupported target '${target}'.`);
  537. }
  538. const [, namespace, opname] = match;
  539. name = `${namespace}::${opname}`;
  540. } else {
  541. throw new pytorch.Error(`Unsupported target '${target}'.`);
  542. }
  543. this.type = {
  544. identifier: name,
  545. name: name.indexOf('::') === -1 ? name : name.split('::').pop().split('.')[0]
  546. };
  547. const schema = obj.target._schema;
  548. if (schema && schema.category) {
  549. this.type.category = schema.category;
  550. }
  551. let args = obj.args.map((arg, index) => {
  552. if (!schema) {
  553. return ['', arg];
  554. }
  555. if (Array.isArray(schema.arguments) && index < schema.arguments.length) {
  556. return [schema.arguments[index].name, arg];
  557. }
  558. if (schema.is_vararg) {
  559. return ['', arg];
  560. }
  561. throw new pytorch.Error('Unsupported schema argument.');
  562. });
  563. const inputs = new Map((schema ? schema.arguments : []).map((arg) => [arg.name, arg]));
  564. args = args.concat(Array.from(obj.kwargs));
  565. for (const [name, arg] of args) {
  566. let type = inputs.has(name) ? pytorch.Utility.toType(inputs.get(name).real_type) : null;
  567. if (arg instanceof torch.fx.node.Node) {
  568. let argument = null;
  569. if (arg.op === 'get_attr' && arg.users.size === 1) {
  570. const subgraph = context.function(arg);
  571. if (subgraph) {
  572. argument = new pytorch.Argument(name, subgraph, 'function');
  573. }
  574. }
  575. if (!argument) {
  576. const value = context.value(arg);
  577. argument = new pytorch.Argument(name, [value]);
  578. }
  579. this.inputs.push(argument);
  580. } else if (Array.isArray(arg) && arg.every((arg) => arg instanceof torch.fx.node.Node || arg === null)) {
  581. const list = arg.map((arg) => arg === null ? null : context.value(arg));
  582. const argument = new pytorch.Argument(name, list);
  583. this.inputs.push(argument);
  584. } else if (Array.isArray(arg)) {
  585. const list = arg.map((arg) => arg instanceof torch.fx.node.Node ? context.value(arg) : arg);
  586. const argument = new pytorch.Argument(name, list, type || 'attribute');
  587. this.inputs.push(argument);
  588. } else if (arg instanceof torch.dtype || arg instanceof torch.device || arg instanceof torch.layout || arg instanceof torch.memory_format) {
  589. const argument = new pytorch.Argument(name, arg.toString(), type || 'attribute');
  590. this.inputs.push(argument);
  591. } else {
  592. const primitive = typeof arg === 'number' || typeof arg === 'boolean' || typeof arg === 'string' || arg === null;
  593. type = type === 'tensor' && primitive ? null : type;
  594. const argument = new pytorch.Argument(name, arg, type || 'attribute');
  595. this.inputs.push(argument);
  596. }
  597. }
  598. let outputs = [obj];
  599. if (obj.users.size > 1) {
  600. const users = Array.from(obj.users.keys());
  601. if (users.every((user) => user.op === 'call_function' && user.target.__module__ === 'operator' && user.target.__name__ === 'getitem')) {
  602. outputs = new Array(obj.users.size);
  603. for (const user of users) {
  604. const [, index] = user.args;
  605. outputs[index] = user;
  606. }
  607. }
  608. }
  609. for (let i = 0; i < outputs.length; i++) {
  610. const node = outputs[i];
  611. const value = context.value(node);
  612. const name = schema && schema.returns && schema.returns[i] ? schema.returns[i].name || 'output' : 'output';
  613. const argument = new pytorch.Argument(name, [value]);
  614. this.outputs.push(argument);
  615. }
  616. for (const [name, value] of obj.meta) {
  617. if (name === 'val' || name === 'torch_fn' ||
  618. (Array.isArray(value) && value.length === 0) ||
  619. (value instanceof Map && value.size === 0)) {
  620. continue;
  621. }
  622. if (typeof value === 'string') {
  623. const argument = new pytorch.Argument(name, value, 'string');
  624. this.metadata.push(argument);
  625. } else if (Array.isArray(value) && value.every((item) => typeof item === 'string')) {
  626. const argument = new pytorch.Argument(name, value, 'string[]');
  627. this.metadata.push(argument);
  628. } else if (value instanceof Map && value.size > 0) {
  629. // const argument = new pytorch.Argument(name, Object.fromEntries(Array.from(value)));
  630. // this.metadata.push(argument);
  631. } else {
  632. // const argument = new pytorch.Argument(name, value);
  633. // this.metadata.push(argument);
  634. }
  635. }
  636. } else if (obj.op === 'placeholder') {
  637. this.type = { name: obj.op };
  638. {
  639. const value = context.value(obj);
  640. const argument = new pytorch.Argument('value', [value]);
  641. this.inputs.push(argument);
  642. }
  643. {
  644. const node = new torch.fx.node.Node(null, obj.name);
  645. node.meta = obj.meta;
  646. const value = context.value(node);
  647. const argument = new pytorch.Argument('value', [value]);
  648. this.outputs.push(argument);
  649. }
  650. } else if (obj.op === 'get_attr') {
  651. this.type = { name: obj.op };
  652. const subgraph = context.function(obj);
  653. if (subgraph) {
  654. this.inputs.push(new pytorch.Argument('name', subgraph, 'function'));
  655. } else {
  656. this.inputs.push(new pytorch.Argument('name', obj.target, 'string'));
  657. }
  658. const value = context.value(obj);
  659. this.outputs.push(new pytorch.Argument('value', [value]));
  660. } else if (obj.op === 'root') {
  661. this.type = { name: obj.op };
  662. } else {
  663. throw new pytorch.Error(`Unsupported node operation '${obj.op}'.`);
  664. }
  665. } else {
  666. if (torch && obj instanceof torch.ScriptObject) {
  667. type = obj._type().qualified_name();
  668. obj = obj._ivalue;
  669. } else if (torch && obj instanceof torch.jit._script.RecursiveScriptModule && obj._c && obj._c.qualified_name) {
  670. type = obj._c._type();
  671. const target = {
  672. _modules: obj._modules,
  673. _parameters: obj._parameters,
  674. _buffers: obj._buffers,
  675. };
  676. for (let i = 0; i < type.numAttributes(); i++) {
  677. if (!type.is_parameter(i) && !type.is_buffer(i) && !type.getAttribute(i).is_module()) {
  678. const k = type.getAttributeName(i);
  679. target[k] = obj.__getattr__(k);
  680. }
  681. }
  682. type = obj._c.qualified_name;
  683. obj = target;
  684. }
  685. if (!type) {
  686. if (torch && obj instanceof torch.jit._script.RecursiveScriptModule && obj._c && obj._c.qualified_name) {
  687. type = obj._c.qualified_name;
  688. } else if (pytorch.Utility.isInstance(obj, 'builtins.function')) {
  689. type = `${obj.__module__}.${obj.__name__}`;
  690. obj = {};
  691. } else if (obj && obj.__class__ && obj.__class__.__module__ && obj.__class__.__name__) {
  692. type = `${obj.__class__.__module__}.${obj.__class__.__name__}`;
  693. } else {
  694. type = 'builtins.object';
  695. }
  696. }
  697. if (type instanceof nnapi.Graph) {
  698. this.type = type;
  699. } else {
  700. const key = type.startsWith('__torch__.') ? type.substring(10) : type;
  701. const value = metadata.type(key);
  702. this.type = value ? { ...value } : { name: type };
  703. this.type.identifier = type;
  704. }
  705. stack = stack || new Set();
  706. const weights = pytorch.Utility.weights(obj);
  707. if (weights) {
  708. const type = this.type.name;
  709. this.type = new pytorch.Graph(execution, metadata, 'weights', '', weights);
  710. this.type.name = type;
  711. } else if (obj && pytorch.Utility.isInstance(obj, 'fastai.data.core.DataLoaders')) {
  712. // continue
  713. } else if (obj && pytorch.Utility.isInstance(obj, '__torch__.torch.classes._nnapi.Compilation')) {
  714. // continue
  715. } else if (obj && type === 'builtins.bytearray') {
  716. const argument = new pytorch.Argument('value', Array.from(obj), 'byte[]');
  717. this.inputs.push(argument);
  718. } else if (obj) {
  719. const inputs = new Map(Array.isArray(this.type.inputs) ? this.type.inputs.map((input) => [input.name, input]) : []);
  720. const list = obj instanceof Map ? Array.from(obj) : Object.entries(obj);
  721. for (const [name, value] of list) {
  722. if (name === '__class__' || name === '__name__') {
  723. continue;
  724. } else if (pytorch.Utility.isInstance(value, 'collections.OrderedDict') && value instanceof Map && value.size === 0) {
  725. continue;
  726. } else if (pytorch.Utility.isInstance(value, 'builtins.set') && value instanceof Set && value.size === 0) {
  727. continue;
  728. } else if (pytorch.Utility.isInstance(value, 'builtins.list') && Array.isArray(value) && value.length === 0) {
  729. continue;
  730. } else if (pytorch.Utility.isInstance(value, 'torch.Size') && Array.isArray(value) && value.length === 0) {
  731. continue;
  732. }
  733. let parameters = null;
  734. if ((name === '_parameters' || name === '_buffers') && value instanceof Map) {
  735. parameters = value;
  736. } else if (pytorch.Utility.isTensor(value) || (Array.isArray(value) && value.every((tensor) => pytorch.Utility.isTensor(tensor)))) {
  737. parameters = new Map([[name, value]]);
  738. }
  739. if (parameters) {
  740. for (const [name, value] of parameters) {
  741. const list = Array.isArray(value) ? value.map((item) => pytorch.Utility.toTensor(item)) : [pytorch.Utility.toTensor(value)];
  742. const visible = inputs.has(name) ? inputs.get(name).visible || true : true;
  743. const args = list.filter((value) => value !== null && !value.__origin__).map((value) => {
  744. const name = value && value.name ? value.name : '';
  745. const identifier = list.length === 1 && value && value.__name__ ? value.__name__ : name;
  746. let tensor = null;
  747. if (initializers && initializers.has(value)) {
  748. tensor = initializers.get(value);
  749. } else {
  750. value = value.__source__ ? value.__source__ : value;
  751. tensor = value ? new pytorch.Tensor(context, identifier, value) : null;
  752. }
  753. return new pytorch.Value(identifier, null, null, tensor);
  754. });
  755. const argument = new pytorch.Argument(name, args, null, visible);
  756. this.inputs.push(argument);
  757. if (value && value.__variable__) {
  758. const argument = new pytorch.Argument(name, [context.values.map(value.__variable__)]);
  759. this.outputs.push(argument);
  760. }
  761. }
  762. continue;
  763. }
  764. if (pytorch.Utility.isTensor(value)) {
  765. const tensor = new pytorch.Tensor(context, '', value);
  766. const argument = new pytorch.Argument(name, tensor, 'tensor');
  767. this.inputs.push(argument);
  768. } else if (value && pytorch.Utility.isInstance(value, 'torch.dtype')) {
  769. const node = new pytorch.Node(execution, metadata, null, value.toString(), {}, null, context);
  770. const argument = new pytorch.Argument(name, node, 'object');
  771. this.inputs.push(argument);
  772. } else if (Array.isArray(value) && value.some((value) => pytorch.Utility.isTensor(value)) && value.every((value) => pytorch.Utility.isTensor(value) || value === null)) {
  773. const tensors = value.map((value) => value === null ? value : new pytorch.Tensor(context, '', value));
  774. const argument = new pytorch.Argument(name, tensors, 'tensor[]');
  775. this.inputs.push(argument);
  776. } else if (pytorch.Utility.isInstance(value, 'numpy.ndarray') || pytorch.Utility.isInstance(value, 'numpy.matrix')) {
  777. const tensor = new numpy.Tensor(value);
  778. const argument = new pytorch.Argument(name, tensor, 'tensor');
  779. this.inputs.push(argument);
  780. } else if (Array.isArray(value) && value.every((value) => typeof value === 'string')) {
  781. const argument = new pytorch.Argument(name, value, 'string[]');
  782. this.inputs.push(argument);
  783. } else if (Array.isArray(value) && value.every((value) => typeof value === 'number')) {
  784. const argument = new pytorch.Argument(name, value, 'attribute');
  785. this.inputs.push(argument);
  786. } else if (name === '_modules' && pytorch.Utility.isInstance(value, 'collections.OrderedDict') &&
  787. value instanceof Map && Array.from(value).every(([, value]) => value === null || value.__class__)) {
  788. const list = Array.from(value).filter(([, value]) => !stack.has(value)).map(([name, obj]) => {
  789. stack.add(value);
  790. const type = obj === null ? 'builtins.NoneType' : `${obj.__class__.__module__}.${obj.__class__.__name__}`;
  791. const node = new pytorch.Node(execution, metadata, this.name ? `${this.name}.${name}` : name, type, obj, initializers, context, stack);
  792. stack.delete(value);
  793. return node;
  794. });
  795. const argument = new pytorch.Argument(name, list, 'object[]');
  796. this.inputs.push(argument);
  797. } else if (value && Array.isArray(value) && value.length > 0 && value.every((obj) => Array.isArray(obj) && obj.every((item) => typeof item === 'string' || typeof item === 'number'))) {
  798. const argument = new pytorch.Argument(name, value, 'attribute');
  799. this.inputs.push(argument);
  800. } else if (value && Array.isArray(value) && value.length > 0 && value.every((obj) => obj && (obj.__class__ || obj === Object(obj)))) {
  801. const list = value.filter((value) => !stack.has(value));
  802. const nodes = list.map((value) => {
  803. stack.add(value);
  804. const node = new pytorch.Node(execution, metadata, null, null, value, initializers, context, stack);
  805. stack.delete(value);
  806. return node;
  807. });
  808. const argument = new pytorch.Argument(name, nodes, 'object[]');
  809. this.inputs.push(argument);
  810. } else if (value && (value.__class__ || typeof value === 'object') && !stack.has(value)) {
  811. stack.add(value);
  812. const node = new pytorch.Node(execution, metadata, null, null, value, initializers, context, stack);
  813. stack.delete(value);
  814. const visible = name !== '_metadata' || !pytorch.Utility.isMetadataObject(value);
  815. const argument = new pytorch.Argument(name, node, 'object', visible);
  816. this.inputs.push(argument);
  817. } else {
  818. let schema = metadata.attribute(this.type.identifier, name);
  819. schema = name === 'training' ? { type: 'boolean', visible: false } : schema;
  820. let visible = true;
  821. let obj = value;
  822. const type = schema && schema.type ? schema.type : 'attribute';
  823. if (schema) {
  824. if (schema.visible === false) {
  825. visible = false;
  826. } else if (schema.default !== undefined) {
  827. if (Array.isArray(obj)) {
  828. if (Array.isArray(schema.default)) {
  829. visible = obj.length !== schema.default || !obj.every((item, index) => item === schema.default[index]);
  830. } else {
  831. visible = !obj.every((item) => item === schema.default);
  832. }
  833. } else {
  834. visible = obj !== schema.default;
  835. }
  836. }
  837. }
  838. if (Array.isArray(obj) && obj.length > 0 && obj.every((obj) => obj && obj.__class__ && obj.__class__.__module__ && obj.__class__.__module__.startsWith('torch.nn'))) {
  839. obj = '?';
  840. }
  841. const argument = new pytorch.Argument(name, obj, type, visible);
  842. this.inputs.push(argument);
  843. }
  844. }
  845. }
  846. }
  847. }
  848. };
pytorch.Tensor = class {

    // Adapts a torch tensor (eager or TorchScript) to the viewer's tensor
    // interface. Resolves the backing storage plus the byte offset/length this
    // view covers so `values` can lazily read only the relevant bytes.
    constructor(context, name, tensor) {
        this.name = name || '';
        this.attributes = [];
        // torch.nn.Parameter-style wrappers expose the real tensor as `.data`.
        tensor = tensor.data ? tensor.data : tensor;
        const storage = tensor.storage();
        this.type = context.type(tensor);
        const layout = this.type.layout;
        const size = this.type.shape.dimensions || [];
        if (layout) {
            // Sparse tensor: represented by two dense component tensors.
            this.indices = new pytorch.Tensor(context, '', tensor.indices);
            this._values = new pytorch.Tensor(context, '', tensor.values);
        } else {
            this.encoding = '<'; // little-endian raw bytes
            this.indices = null;
            this.stride = tensor.stride();
            const stride = this.stride;
            const offset = tensor.storage_offset();
            if (storage) {
                this._data = storage.data;
                // Number of storage elements spanned by this view:
                // 1 + sum(stride[i] * (size[i] - 1)); zero when any dim is 0.
                let length = 0;
                if (!Array.isArray(stride)) {
                    length = storage.size();
                } else if (size.every((v) => v !== 0)) {
                    length = size.reduce((a, v, i) => a + stride[i] * (v - 1), 1);
                }
                if (storage && typeof storage.size === 'function') {
                    // Record a byte window only when the view does not cover the
                    // entire storage; `values` then slices the stream accordingly.
                    if (offset !== 0 || length !== storage.size()) {
                        const itemsize = storage.dtype.itemsize();
                        this._offset = itemsize * offset;
                        this._length = itemsize * length;
                    }
                }
            }
        }
        // Tensor subclasses may declare extra serialized attributes and nested
        // tensors via class-level name lists.
        const type = tensor.__class__ || {};
        if (type.tensor_attribute_names) {
            for (const name of type.tensor_attribute_names) {
                let value = tensor[name];
                if (value !== undefined) {
                    if (value && typeof value.__reduce__ === 'function') {
                        value = value.__reduce__();
                    }
                    const attribute = new pytorch.Argument(name, value, 'attribute');
                    this.attributes.push(attribute);
                }
            }
        }
        if (type.tensor_data_names) {
            for (const name of type.tensor_data_names) {
                const value = tensor[name];
                if (value !== undefined && pytorch.Utility.isTensor(value)) {
                    const attribute = new pytorch.Argument(name, new pytorch.Tensor(context, name, value), 'tensor');
                    this.attributes.push(attribute);
                }
            }
        }
    }

    // Raw tensor data. For sparse layouts this is the nested values tensor;
    // otherwise it is the (possibly sliced) byte buffer from storage.
    get values() {
        const type = this.type.layout;
        if (type && type.startsWith('sparse.')) {
            return this._values;
        }
        if (this._data instanceof Uint8Array) {
            return this._data;
        }
        if (this._data && this._offset !== undefined) {
            // Slice the backing stream without disturbing its position.
            const stream = this._data;
            const position = stream.position;
            stream.seek(this._offset);
            const values = stream.peek(this._length);
            stream.seek(position);
            return values;
        }
        if (this._data) {
            return this._data.peek();
        }
        return null;
    }
};
  929. pytorch.TensorType = class {
  930. constructor(dataType, shape, layout) {
  931. this.dataType = dataType;
  932. this.shape = shape;
  933. this.layout = layout;
  934. }
  935. toString() {
  936. return this.dataType + this.shape.toString();
  937. }
  938. };
  939. pytorch.TensorShape = class {
  940. constructor(dimensions = []) {
  941. this.dimensions = dimensions;
  942. }
  943. toString() {
  944. if (this.dimensions && this.dimensions.length > 0) {
  945. return `[${this.dimensions.map((dimension) => dimension.toString()).join(',')}]`;
  946. }
  947. return '';
  948. }
  949. };
pytorch.Context = class {

    // Shared state for translating torch.fx graphs into viewer graphs:
    // caches a pytorch.Value per fx node and a submodule per 'get_attr' node.
    constructor(execution, metadata) {
        this.execution = execution;
        this.torch = execution ? execution.__import__('torch') : null;
        this.metadata = metadata;
        this.values = new Map();
        this.modules = new Map();
    }

    // Builds a pytorch.TensorType for a tensor-like object, normalizing the
    // float8/float4 dtype spellings and mapping torch.sparse_* layouts to
    // the 'sparse.<kind>' layout notation.
    type(tensor) {
        let dataType = tensor.dtype.__reduce__();
        switch (dataType) {
            case 'float8_e5m2': dataType = 'float8e5m2'; break;
            case 'float8_e5m2fnuz': dataType = 'float8e5m2fnuz'; break;
            case 'float8_e4m3fn': dataType = 'float8e4m3fn'; break;
            case 'float8_e4m3fnuz': dataType = 'float8e4m3fnuz'; break;
            case 'float8_e8m0fnu': dataType = 'float8e8m0fnu'; break;
            case 'float4_e2m1fn_x2': dataType = 'float4e2m1fnx2'; break;
            default: break;
        }
        const size = tensor.size ? tensor.size() : tensor.shape;
        const shape = new pytorch.TensorShape(size || []);
        const layout = tensor.layout ? tensor.layout.__str__() : null;
        if (layout && layout.startsWith('torch.sparse_')) {
            // e.g. 'torch.sparse_coo' -> 'sparse.coo'
            return new pytorch.TensorType(dataType, shape, layout.split('.').pop().replace('_', '.'));
        }
        return new pytorch.TensorType(dataType, shape);
    }

    // Returns the cached pytorch.Value for an fx node, creating it on first
    // use; type information is derived from the node's 'val' meta entry when
    // it carries a dtype. Non-fx objects yield null.
    value(obj) {
        const torch = this.torch;
        if (obj instanceof torch.fx.node.Node) {
            if (!this.values.has(obj)) {
                let type = null;
                const val = obj.meta ? obj.meta.get('val') : null;
                if (val && val.dtype) {
                    type = this.type(val);
                }
                const value = new pytorch.Value(obj.name, type);
                this.values.set(obj, value);
            }
            return this.values.get(obj);
        }
        return null;
    }

    // Lazily wraps the submodule registered for a 'get_attr' node (see
    // graph() below) in a pytorch.Graph; returns null when no submodule
    // was registered for the node.
    function(obj) {
        const torch = this.torch;
        if (obj instanceof torch.fx.node.Node) {
            let subgraph = this.modules.get(obj);
            if (subgraph) {
                if (subgraph instanceof pytorch.Graph === false) {
                    subgraph = new pytorch.Graph(this.execution, this.metadata, 'function', obj.target, subgraph);
                    this.modules.set(obj, subgraph);
                }
                return subgraph;
            }
        }
        return null;
    }

    // Populates `target` (an object with inputs/outputs/nodes arrays) from a
    // torch.fx module's graph. `inputs` controls whether placeholders are
    // surfaced as graph inputs.
    graph(target, module, inputs) {
        const graph = module.graph;
        if (module.named_modules) {
            // Pre-register submodules referenced by 'get_attr' nodes so they
            // can later be rendered as nested functions via function().
            const modules = module.named_modules();
            for (const obj of graph.nodes) {
                if (obj.op === 'get_attr') {
                    const submodule = modules.get(obj.target);
                    if (submodule && submodule.graph) {
                        this.modules.set(obj, submodule);
                    }
                }
            }
        }
        let controlDependency = null;
        for (const obj of graph.nodes) {
            if (obj.op === 'placeholder') {
                // Placeholders become graph inputs (when requested), not nodes.
                if (inputs) {
                    const value = this.value(obj);
                    const argument = new pytorch.Argument(obj.name, [value]);
                    target.inputs.push(argument);
                }
                continue;
            }
            if (obj.op === 'call_function') {
                // operator.getitem calls are folded into their producer's
                // multi-output handling and never become standalone nodes.
                if (obj.target.__module__ === 'operator' && obj.target.__name__ === 'getitem') {
                    continue;
                }
            }
            if (obj.op === 'get_attr') {
                // A single-use registered submodule is inlined at its use site.
                if (this.modules.has(obj) && obj.users.size === 1) {
                    continue;
                }
            }
            if (obj.op === 'output') {
                for (const output of obj.args) {
                    if (output === null || output === undefined) {
                        continue;
                    }
                    if (output.op === 'call_function' && output.target.__module__ === 'operator' && output.target.__name__ === 'getitem') {
                        continue;
                    }
                    const value = this.value(output);
                    const argument = new pytorch.Argument(output.name, [value]);
                    target.outputs.push(argument);
                }
                continue;
            }
            const node = new pytorch.Node(this.execution, this.metadata, obj.name, null, obj, null, this);
            target.nodes.push(node);
            if (controlDependency) {
                node.controlDependencies = node.controlDependencies || [];
                node.controlDependencies.push(controlDependency);
                controlDependency = null;
            }
            // A call whose result is unused exists for its side effects; carry
            // its output into the next node as a control dependency so the
            // ordering constraint remains visible in the rendered graph.
            if (obj.op === 'call_function' && obj.users.size === 0) {
                controlDependency = node.outputs[0].value[0];
            }
        }
    }
};
  1067. pytorch.Reader = class {
  1068. static async open(context) {
  1069. const types = [
  1070. pytorch.Reader.Zip,
  1071. pytorch.Reader.Pickle,
  1072. pytorch.Reader.Tar,
  1073. pytorch.Reader.data_pkl,
  1074. pytorch.Reader.torch_utils,
  1075. pytorch.Reader.Mobile,
  1076. pytorch.Reader.ModelJson,
  1077. pytorch.Reader.IR,
  1078. pytorch.Reader.Index,
  1079. pytorch.Reader.ExportedProgram
  1080. ];
  1081. for (const type of types) {
  1082. // eslint-disable-next-line no-await-in-loop
  1083. const reader = await type.open(context);
  1084. if (reader) {
  1085. return reader;
  1086. }
  1087. }
  1088. return null;
  1089. }
  1090. constructor() {
  1091. this._events = [];
  1092. }
  1093. async read() {
  1094. }
  1095. on(event, callback) {
  1096. this._events.push([event, callback]);
  1097. }
  1098. };
  1099. pytorch.Reader.Tar = class extends pytorch.Reader {
  1100. static async open(context) {
  1101. const entries = await context.peek('tar');
  1102. if (entries instanceof Map && entries.has('pickle')) {
  1103. return new pytorch.Reader.Tar(entries);
  1104. }
  1105. return null;
  1106. }
  1107. constructor(entries) {
  1108. super();
  1109. this.type = 'pytorch.tar';
  1110. this.entries = entries;
  1111. }
  1112. async read() {
  1113. this.format = 'PyTorch v0.1.1';
  1114. const execution = new python.Execution();
  1115. for (const event of this._events) {
  1116. execution.on(event[0], event[1]);
  1117. }
  1118. const torch = execution.__import__('torch');
  1119. this.module = torch.load(this.entries);
  1120. delete this.entries;
  1121. }
  1122. };
  1123. pytorch.Reader.Pickle = class extends pytorch.Reader {
  1124. static async open(context) {
  1125. const stream = context.stream;
  1126. const signature = [0x80, undefined, 0x8a, 0x0a, 0x6c, 0xfc, 0x9c, 0x46, 0xf9, 0x20, 0x6a, 0xa8, 0x50, 0x19];
  1127. if (stream && signature.length <= stream.length && stream.peek(signature.length).every((value, index) => signature[index] === undefined || signature[index] === value)) {
  1128. return new pytorch.Reader.Pickle(stream);
  1129. }
  1130. return null;
  1131. }
  1132. constructor(stream) {
  1133. super();
  1134. this.type = 'pytorch.pickle';
  1135. this.stream = stream;
  1136. }
  1137. async read() {
  1138. this.format = 'PyTorch v0.1.10';
  1139. const data = this.stream.length < 0x7ffff000 ? this.stream.peek() : this.stream;
  1140. delete this.stream;
  1141. const execution = new python.Execution();
  1142. for (const event of this._events) {
  1143. execution.on(event[0], event[1]);
  1144. }
  1145. const torch = execution.__import__('torch');
  1146. this.module = torch.load(data);
  1147. }
  1148. };
pytorch.Reader.data_pkl = class extends pytorch.Reader {

    // Recognizes bare pickle payloads that hold PyTorch content: a TorchScript
    // class instance, a tensor (or tensor collection), or a module-like object.
    static async open(context) {
        const obj = await context.peek('pkl');
        if (obj) {
            // TorchScript class instances live under the '__torch__' namespace.
            if (obj.__class__ && obj.__class__.__module__ && obj.__class__.__name__) {
                const name = `${obj.__class__.__module__}.${obj.__class__.__name__}`;
                if (name.startsWith('__torch__.')) {
                    return new pytorch.Reader.data_pkl('', obj);
                }
            }
            // A single tensor or a non-empty list of tensors.
            if (pytorch.Utility.isTensor(obj)) {
                return new pytorch.Reader.data_pkl('tensor', obj);
            }
            if (Array.isArray(obj) && obj.length > 0 && obj.every((tensor) => pytorch.Utility.isTensor(tensor))) {
                return new pytorch.Reader.data_pkl('tensor', obj);
            }
            // A Map or plain object with at least one tensor-valued entry
            // (a state_dict-style checkpoint).
            if (obj instanceof Map) {
                const entries = Array.from(obj).filter(([, value]) => pytorch.Utility.isTensor(value));
                if (entries.length > 0) {
                    return new pytorch.Reader.data_pkl('tensor', obj);
                }
            } else if (!Array.isArray(obj)) {
                const entries = Object.entries(obj).filter(([, value]) => pytorch.Utility.isTensor(value));
                if (entries.length > 0) {
                    return new pytorch.Reader.data_pkl('tensor', obj);
                }
            }
            // A module at the top level or nested under a common wrapper key.
            for (const key of ['', 'model', 'net']) {
                const module = key === '' ? obj : obj[key];
                if (module && module._modules && pytorch.Utility.isInstance(module._modules, 'collections.OrderedDict')) {
                    return new pytorch.Reader.data_pkl('module', module);
                }
            }
        }
        return null;
    }

    // NOTE(review): the `type` argument ('', 'tensor' or 'module') passed by
    // open() is currently unused — this.type is always the fixed reader id.
    // Confirm whether downstream code is meant to receive the distinction
    // before removing the parameter.
    constructor(type, module) {
        super();
        this.type = 'pytorch.data.pkl';
        this.format = 'PyTorch Pickle';
        this.module = module;
    }

    // Content was fully materialized in open(); nothing further to read.
    async read() {
    }
};
  1194. pytorch.Reader.torch_utils = class extends pytorch.Reader {
  1195. static async open(context) {
  1196. const stream = context.stream;
  1197. if (stream && stream.length > 1) {
  1198. const buffer = stream.peek(Math.min(1024, stream.length));
  1199. if (buffer[0] === 0x80) {
  1200. const content = String.fromCharCode.apply(null, buffer);
  1201. if (content.indexOf('torch_utils') !== -1) {
  1202. const obj = await context.peek('pkl');
  1203. if (obj && Object.entries(obj).some(([, value]) => pytorch.Utility.isInstance(value, 'torch.nn.modules.module.Module'))) {
  1204. return new pytorch.Reader.torch_utils(obj);
  1205. }
  1206. }
  1207. }
  1208. }
  1209. return null;
  1210. }
  1211. constructor(obj) {
  1212. super();
  1213. this.type = 'pytorch.torch_utils';
  1214. this.obj = obj;
  1215. }
  1216. async read() {
  1217. this.format = 'PyTorch torch_utils';
  1218. this.module = this.obj;
  1219. delete this.obj;
  1220. }
  1221. };
  1222. pytorch.Reader.Mobile = class extends pytorch.Reader {
  1223. static async open(context) {
  1224. const reader = await context.peek('flatbuffers.binary');
  1225. if (reader && reader.identifier === 'PTMF') {
  1226. return new pytorch.Reader.Mobile(context);
  1227. }
  1228. return null;
  1229. }
  1230. constructor(context) {
  1231. super();
  1232. this.type = 'pytorch.mobile';
  1233. this.context = context;
  1234. }
  1235. async read(metadata) {
  1236. const execution = new pytorch.Execution(null, metadata);
  1237. for (const event of this._events) {
  1238. execution.on(event[0], event[1]);
  1239. }
  1240. const stream = this.context.stream;
  1241. const torch = execution.__import__('torch');
  1242. torch.mobile = await this.context.require('./pytorch-schema');
  1243. torch.mobile = torch.mobile.torch.jit.mobile;
  1244. this.module = torch.jit.jit_module_from_flatbuffer(stream);
  1245. const version = this.module._c._bytecode_version.toString();
  1246. this.format = pytorch.Utility.format('PyTorch Mobile', version);
  1247. delete this.context;
  1248. }
  1249. };
pytorch.Reader.Zip = class extends pytorch.Reader {

    // Detects TorchScript / PyTorch zip archives ('data.pkl') and torch.package
    // archives ('.data/version'); 'model.json' archives are declined here and
    // handled by pytorch.Reader.ModelJson instead.
    static async open(context) {
        const entries = await context.peek('zip');
        if (entries instanceof Map && entries.size > 0) {
            // Compute the length of the leading directory prefix shared by all
            // entry paths. Paths are split and reversed so pop() consumes
            // segments root-first; the loop stops once segments diverge or a
            // path has only its file name left.
            let prefix = 0;
            const paths = Array.from(entries.keys()).map((path) => path.replace(/\\/g, '/').split('/').reverse());
            for (let set = new Set(); set && paths.length > 0;) {
                set = new Set(paths.map((path) => path.length > 1 ? path.pop() : null));
                set = set.size > 1 || set.keys().next().value === null ? null : set;
                prefix += set ? set.keys().next().value.length + 1 : 0;
            }
            const records = new Map(Array.from(entries).map(([name, value]) => [name.substring(prefix), value]));
            if (records.has('model.json')) {
                return null;
            }
            if (records.has('data.pkl')) {
                return new pytorch.Reader.Zip(entries);
            }
            if (records.has('.data/version') && !records.has('archive_format')) {
                return new pytorch.Reader.Package(entries);
            }
        }
        return null;
    }

    constructor(entries) {
        super();
        this.type = 'pytorch.zip';
        // https://github.com/pytorch/pytorch/blob/main/torch/csrc/jit/OVERVIEW.md
        // https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/docs/serialization.md
        this._entries = entries;
    }

    // Loads the archive: modules with a 'constants.pkl' record go through
    // torch.jit.load (and are inline-passed when a 'forward' method exists);
    // anything else is treated as a plain checkpoint via torch.load.
    async read(metadata) {
        this.execution = new pytorch.Execution(null, metadata);
        for (const event of this._events) {
            this.execution.on(event[0], event[1]);
        }
        const torch = this.execution.__import__('torch');
        const reader = new torch.PyTorchFileReader(this._entries);
        let torchscript = reader.has_record('constants.pkl');
        const version = reader.version();
        if (torchscript) {
            metadata.register(this.execution);
            this.module = torch.jit.load(reader);
            // Only keep the TorchScript label when a 'forward' method exists.
            torchscript = this.module._c._has_method('forward');
            if (torchscript) {
                // console.log(this.module.graph.toString());
                torch._C._jit_pass_inline(this.module.graph);
                // console.log(this.module.graph.toString());
            }
        } else {
            const records = reader.get_all_records().map((key) => [key, reader.get_record(key)]);
            const entries = new Map(records);
            this.module = torch.load(entries);
        }
        const name = torchscript ? 'TorchScript' : 'PyTorch';
        this.format = pytorch.Utility.format(name, version);
        // NOTE(review): _model is never assigned in this class; this delete
        // appears vestigial — confirm before removing.
        delete this._model;
        delete this._entries;
    }
};
  1310. pytorch.Reader.ModelJson = class extends pytorch.Reader {
  1311. static async open(context) {
  1312. const identifier = context.identifier;
  1313. if (identifier === 'model.json') {
  1314. const model = await context.peek('json');
  1315. if (model && model.mainModule) {
  1316. const entries = new Map();
  1317. entries.set('model.json', context.stream);
  1318. return new pytorch.Reader.ModelJson(context, entries, model);
  1319. }
  1320. }
  1321. return null;
  1322. }
  1323. constructor(context, entries, model) {
  1324. super();
  1325. this.type = 'pytorch.model.json';
  1326. this._context = context;
  1327. this._entries = entries;
  1328. this._model = model;
  1329. }
  1330. async read(metadata) {
  1331. pytorch.proto = await this._context.require('./pytorch-proto');
  1332. const keys = [
  1333. 'attributes.pkl',
  1334. 'version',
  1335. ...this._model.tensors.filter((tensor) => tensor && tensor.data && tensor.data.key).map((tensor) => tensor.data.key)
  1336. ];
  1337. const walk = (module) => {
  1338. if (module.torchscriptArena && module.torchscriptArena.key) {
  1339. keys.push(module.torchscriptArena.key);
  1340. }
  1341. for (const submodule of module.submodules || []) {
  1342. walk(submodule);
  1343. }
  1344. };
  1345. walk(this._model.mainModule);
  1346. const values = await Promise.all(keys.map((name) => this._context.fetch(name).then((context) => context.stream).catch(() => null)));
  1347. for (let i = 0; i < keys.length; i++) {
  1348. if (values[i]) {
  1349. this._entries.set(keys[i], values[i]);
  1350. }
  1351. }
  1352. this.execution = new pytorch.Execution(null, metadata);
  1353. this.execution.proto = pytorch.proto;
  1354. for (const event of this._events) {
  1355. this.execution.on(event[0], event[1]);
  1356. }
  1357. const torch = this.execution.__import__('torch');
  1358. const reader = new torch.PyTorchFileReader(this._entries);
  1359. if (this._model && this._model.producerName) {
  1360. this.producer = this._model.producerName + (this._model.producerVersion ? ` v${this._model.producerVersion}` : '');
  1361. }
  1362. this.format = reader.has_record('attributes.pkl') ? 'TorchScript v1.1' : 'TorchScript v1.0';
  1363. metadata.register(this.execution);
  1364. this.module = torch.jit.load(reader);
  1365. if (this.module._c._has_method('forward')) {
  1366. // console.log(this.module.graph.toString());
  1367. torch._C._jit_pass_inline(this.module.graph);
  1368. // console.log(this.module.graph.toString());
  1369. }
  1370. delete this._context;
  1371. delete this._model;
  1372. delete this._entries;
  1373. }
  1374. };
  1375. pytorch.Reader.IR = class extends pytorch.Reader {
  1376. static async open(context) {
  1377. const reader = await context.read('text', 0x100);
  1378. if (reader && reader.length > 0) {
  1379. const line = reader.read('\n');
  1380. if (line.startsWith('graph(')) {
  1381. return new pytorch.Reader.IR(context);
  1382. }
  1383. }
  1384. return null;
  1385. }
  1386. constructor(context) {
  1387. super();
  1388. this.type = 'pytorch.ir';
  1389. this.context = context;
  1390. }
  1391. async read(metadata) {
  1392. this.format = 'TorchScript IR';
  1393. this.execution = new pytorch.Execution(null, metadata);
  1394. for (const event of this._events) {
  1395. this.execution.on(event[0], event[1]);
  1396. }
  1397. // this.execution.graph;
  1398. // context reader = await context.read('text', 0x100);
  1399. throw new pytorch.Error('TorchScript IR parser not implemented.');
  1400. }
  1401. };
  1402. pytorch.Reader.Index = class extends pytorch.Reader {
  1403. static async open(context) {
  1404. const obj = await context.peek('json');
  1405. if (obj && obj.weight_map) {
  1406. const entries = Object.entries(obj.weight_map);
  1407. if (entries.length > 0 && entries.every(([, value]) => typeof value === 'string' && value.endsWith('.bin'))) {
  1408. return new pytorch.Reader.Index(context, entries);
  1409. }
  1410. }
  1411. return null;
  1412. }
  1413. constructor(context, entries) {
  1414. super();
  1415. this.type = 'pytorch.index';
  1416. this.context = context;
  1417. this._entries = entries;
  1418. }
  1419. async read(metadata) {
  1420. this.format = 'PyTorch';
  1421. const weight_map = new Map(this._entries);
  1422. const keys = new Set(weight_map.keys());
  1423. const files = Array.from(new Set(weight_map.values()));
  1424. const contexts = await Promise.all(files.map((name) => this.context.fetch(name)));
  1425. this.execution = new pytorch.Execution(null, metadata);
  1426. for (const event of this._events) {
  1427. this.execution.on(event[0], event[1]);
  1428. }
  1429. const torch = this.execution.__import__('torch');
  1430. const archives = await Promise.all(contexts.map((context) => context.peek('zip')));
  1431. const formats = new Set(archives.map((entries) => {
  1432. const reader = new torch.PyTorchFileReader(entries);
  1433. const version = reader.version();
  1434. return pytorch.Utility.format('PyTorch', version);
  1435. }));
  1436. if (formats.size === 1) {
  1437. this.format = formats.values().next().value;
  1438. }
  1439. const shards = archives.map((entries) => {
  1440. return torch.load(entries);
  1441. });
  1442. const entries = new Map();
  1443. for (const shard of shards) {
  1444. for (const [key, value] of Array.from(shard)) {
  1445. if (keys.has(key)) {
  1446. entries.set(key, value);
  1447. }
  1448. }
  1449. }
  1450. this.module = entries;
  1451. delete this.context;
  1452. delete this._entries;
  1453. }
  1454. };
  1455. pytorch.Reader.ExportedProgram = class extends pytorch.Reader {
  1456. static async open(context) {
  1457. const program = await context.peek('json');
  1458. if (program && program.schema_version && program.graph_module) {
  1459. return new pytorch.Reader.ExportedProgram(context, program);
  1460. }
  1461. if (context.identifier === 'archive_format' && context.stream && context.stream.length < 10) {
  1462. const buffer = context.stream.peek();
  1463. const archive_format = String.fromCharCode.apply(null, buffer);
  1464. if (archive_format === 'pt2') {
  1465. return new pytorch.Reader.ExportedProgram(context, null, context);
  1466. }
  1467. }
  1468. return null;
  1469. }
  1470. constructor(context, exported_program, archive_format) {
  1471. super();
  1472. this.type = 'pytorch.export';
  1473. this.context = context;
  1474. this.archive_format = archive_format;
  1475. this.exported_program = exported_program;
  1476. }
  1477. async read(metadata) {
  1478. this.format = 'PyTorch Export';
  1479. const f = new Map();
  1480. const exported_programs = new Map();
  1481. if (this.archive_format) {
  1482. for (const name of this.context.container.entries.keys()) {
  1483. const match = name.match(/^models\/([^/]+)\.json$/);
  1484. if (match) {
  1485. const [, model_name] = match;
  1486. /* eslint-disable no-await-in-loop */
  1487. const model = await this.context.fetch(`models/${model_name}.json`);
  1488. const exported_program = await model.read('json');
  1489. exported_programs.set(model_name, exported_program);
  1490. f.set(`models/${model_name}.json`, exported_program);
  1491. const sample_inputs = await this._fetch(`data/sample_inputs/${model_name}.pt`, 'zip');
  1492. f.set(`data/sample_inputs/${model_name}.pt`, sample_inputs);
  1493. const weights_config = await this._fetch(`data/weights/${model_name}_weights_config.json`, 'json');
  1494. if (weights_config) {
  1495. f.set(`data/weights/${model_name}_weights_config.json`, weights_config);
  1496. for (const payload_meta of Object.values(weights_config.config)) {
  1497. const type = payload_meta.use_pickle ? 'zip' : 'binary';
  1498. const weight_data = await this._fetch(`data/weights/${payload_meta.path_name}`, type);
  1499. if (weight_data) {
  1500. f.set(`data/weights/${payload_meta.path_name}`, weight_data);
  1501. }
  1502. }
  1503. } else {
  1504. const weights = await this._fetch(`data/weights/${model_name}.pt`, 'zip');
  1505. f.set(`data/weights/${model_name}.pt`, weights);
  1506. }
  1507. const constants_config = await this._fetch(`data/constants/${model_name}_constants_config.json`, 'json');
  1508. if (constants_config) {
  1509. f.set(`data/constants/${model_name}_constants_config.json`, constants_config);
  1510. for (const payload_meta of Object.values(constants_config.config)) {
  1511. // eslint-enable no-await-in-loop
  1512. const type = payload_meta.use_pickle ? 'zip' : 'binary';
  1513. const constant_data = await this._fetch(`data/constants/${payload_meta.path_name}`, type);
  1514. if (constant_data) {
  1515. f.set(`data/constants/${payload_meta.path_name}`, constant_data);
  1516. }
  1517. }
  1518. } else {
  1519. const constants = await this._fetch(`data/constants/${model_name}.pt`);
  1520. f.set(`data/constants/${model_name}.pt`, constants);
  1521. }
  1522. /* eslint-enable no-await-in-loop */
  1523. }
  1524. }
  1525. const byteorder = await this._fetch('byteorder', 'text') || 'little';
  1526. f.set('byteorder', byteorder);
  1527. } else {
  1528. this.version = await this._fetch('version', 'text') || '';
  1529. this.version = this.version.split('\n').shift().trim();
  1530. const weights = await this._fetch('serialized_state_dict.pt', 'zip') || await this._fetch('serialized_state_dict.json', 'zip');
  1531. const constants = await this._fetch('serialized_constants.pt', 'zip') || await this._fetch('serialized_constants.json', 'zip');
  1532. const sample_inputs = await this._fetch('serialized_example_inputs.pt', 'zip');
  1533. f.set('models/model.json', this.exported_program);
  1534. f.set('data/weights/model.pt', weights);
  1535. f.set('data/constants/model.pt', constants);
  1536. f.set('data/sample_inputs/model.pt', sample_inputs);
  1537. exported_programs.set('', this.exported_program);
  1538. }
  1539. if (!this.version) {
  1540. const versions = new Set();
  1541. for (const exported_program of exported_programs.values()) {
  1542. const schema_version = exported_program.schema_version;
  1543. if (schema_version && schema_version.major && schema_version.minor) {
  1544. versions.add(`${schema_version.major}.${schema_version.minor}`);
  1545. }
  1546. }
  1547. if (versions.size === 1) {
  1548. this.version = versions.values().next().value;
  1549. }
  1550. }
  1551. this.format = this.version ? `${this.format} v${this.version}` : this.format;
  1552. this.execution = new python.Execution();
  1553. for (const event of this._events) {
  1554. this.execution.on(event[0], event[1]);
  1555. }
  1556. metadata.register(this.execution);
  1557. const torch = this.execution.__import__('torch');
  1558. for (const exported_program of exported_programs.values()) {
  1559. if (exported_program.graph_module.graph.constants) {
  1560. // eslint-disable-next-line no-await-in-loop
  1561. const zip = await import('./zip.js');
  1562. const constants = exported_program.graph_module.graph.constants;
  1563. for (const key of Object.keys(constants)) {
  1564. const value = constants[key];
  1565. const str = atob(value);
  1566. const buffer = new Uint8Array(str.length);
  1567. for (let i = 0; i < str.length; i++) {
  1568. buffer[i] = str.charCodeAt(i);
  1569. }
  1570. const archive = zip.Archive.open(buffer);
  1571. constants[key] = archive.entries;
  1572. }
  1573. }
  1574. }
  1575. delete this.exported_program;
  1576. delete this.context;
  1577. const pt2_contents = torch.export.pt2_archive._package.load_pt2(f);
  1578. this.modules = pt2_contents.exported_programs;
  1579. }
  1580. async _fetch(name, type) {
  1581. try {
  1582. const context = await this.context.fetch(name);
  1583. if (context) {
  1584. switch (type) {
  1585. case 'zip':
  1586. return await context.peek('zip');
  1587. case 'json':
  1588. return await context.read('json');
  1589. case 'text': {
  1590. const reader = await context.read('text');
  1591. if (reader) {
  1592. return reader.read();
  1593. }
  1594. break;
  1595. }
  1596. case 'binary': {
  1597. if (context && context.stream) {
  1598. return context.stream.peek();
  1599. }
  1600. break;
  1601. }
  1602. default: {
  1603. throw new pytorch.Error(`Unsupported context type '${type}.`);
  1604. }
  1605. }
  1606. }
  1607. } catch {
  1608. // continue regardless of error
  1609. }
  1610. return null;
  1611. }
  1612. };
// Python execution environment specialized for PyTorch: registers the mobile
// flatbuffer loader, TorchBind custom-class stubs and operator call dispatch
// on top of the generic python.Execution.
pytorch.Execution = class extends python.Execution {

    constructor(sources, metadata) {
        super(sources);
        this._metadata = metadata;
        // eslint-disable-next-line consistent-this
        const execution = this;
        const torch = this.torch;
        // Parses a mobile flatbuffer module ('PTMF') and wraps the resulting
        // C++-style module object as a script module.
        this.registerFunction('torch.jit.jit_module_from_flatbuffer', (f) => {
            const cu = new torch.jit.CompilationUnit();
            cu.execution = execution;
            const stream = f;
            const reader = flatbuffers.BinaryReader.open(stream);
            const module = torch.mobile.serialization.Module.create(reader);
            const loader = new torch._C.FlatBuffersLoader(cu);
            const cpp_module = loader.parseModule(module);
            // parse_and_initialize_jit_module
            // const mobilem = parse_and_initialize_mobile_module_for_jit(data, jit_files, jit_constants);
            // const m = jitModuleFromSourceAndConstants(mobilem._ivalue(), jit_files, jit_constants, mobilem.bytecode_version());
            // throw new pytorch.Error('torch.jit.mobile.serialization.Module not supported.');
            return torch.jit._script.wrap_cpp_module(cpp_module);
        });
        // NNAPI compilation stub: 'run' replays the serialized model as a
        // single graph node with the given input/output tensors.
        this.registerType('__torch__.torch.classes._nnapi.Compilation', class {
            constructor() {
                this.__hide__ = true;
            }
            __init__() {
            }
            init(serialized_model_tensor, parameter_buffers) {
                this.serialized_model_tensor = serialized_model_tensor;
                this.parameter_buffers = parameter_buffers;
                const buffers = parameter_buffers.map((buffer) => buffer.__source__.storage());
                /*
                let buffers = [];
                if (!pytorch.Utility.isInstance(parameter_buffers, 'torch.Value')) {
                    buffers = parameter_buffers.map((buffer) => buffer.__source__.storage());
                }
                */
                const serialized_model = serialized_model_tensor.storage().data;
                this.serialized_model = new nnapi.SerializedModel(serialized_model, buffers);
            }
            run(inputs, outputs) {
                execution.variable(this.serialized_model_tensor);
                this.serialized_model_tensor.__count__ = (this.serialized_model_tensor.__count__ || 0) + 1;
                const type = new nnapi.Graph(this.serialized_model);
                const node = execution.graph.create(type, 0);
                execution.graph.insertNode(node);
                for (const tensor of inputs) {
                    const value = execution.variable(tensor);
                    node.addInput(value);
                }
                for (const tensor of outputs) {
                    execution.variable(tensor, node);
                }
            }
        });
        // Quantized conv packed parameters; only serialization version '2' is
        // supported. The packed_config list encodes stride/padding/dilation/
        // output_padding pairs and the group count (indices 1..9).
        this.registerType('__torch__.torch.classes.quantized.Conv2dPackedParamsBase', class {
            __setstate__(state) {
                if (state[0] !== '2') {
                    throw new pytorch.Error(`Unsupported pack version '${state[0]}'.`);
                }
                const [/* pack_version */, tensors, opt_tensors] = state;
                const packed_config = tensors[0].tolist();
                this.weight = tensors[1];
                this.bias = opt_tensors[0];
                this.stride = [packed_config[1], packed_config[2]];
                this.padding = [packed_config[3], packed_config[4]];
                this.dilation = [packed_config[5], packed_config[6]];
                this.output_padding = [packed_config[7], packed_config[8]];
                this.groups = packed_config[9];
            }
        });
        // Same layout as the 2d variant above.
        this.registerType('__torch__.torch.classes.quantized.Conv3dPackedParamsBase', class {
            __setstate__(state) {
                if (state[0] !== '2') {
                    throw new pytorch.Error(`Unsupported pack version '${state[0]}'.`);
                }
                const [/* pack_version */, tensors, opt_tensors] = state;
                const packed_config = tensors[0].tolist();
                this.weight = tensors[1];
                this.bias = opt_tensors[0];
                this.stride = [packed_config[1], packed_config[2]];
                this.padding = [packed_config[3], packed_config[4]];
                this.dilation = [packed_config[5], packed_config[6]];
                this.output_padding = [packed_config[7], packed_config[8]];
                this.groups = packed_config[9];
            }
        });
        this.registerType('__torch__.torch.classes.quantized.LinearPackedParamsBase', class {
            __setstate__(state) {
                [this.weight, this.bias] = state;
            }
        });
        this.registerType('__torch__.torch.classes.quantized.EmbeddingPackedParamsBase', class {
            __setstate__(state) {
                [this.version, this.tensors, this.doubles, this.longs] = state;
            }
        });
        this.registerType('__torch__.torch.classes.rnn.CellParamsBase', class {
            __setstate__(state) {
                [this.type, this.tensors, this.doubles, this.longs, this.packed_params] = state;
            }
        });
        // XNNPACK op contexts restored from their serialized tuples.
        this.registerType('__torch__.torch.classes.xnnpack.Conv2dOpContext', class {
            __setstate__(state) {
                [this.weight, this.bias, this.stride, this.padding, this.dilation, this.groups, this.output_min, this.output_max] = state;
            }
        });
        this.registerType('__torch__.torch.classes.xnnpack.LinearOpContext', class {
            __setstate__(state) {
                [this.weight, this.bias, this.output_min, this.output_max] = state;
            }
        });
        this.registerType('__torch__.torch.classes.xnnpack.TransposeConv2dOpContext', class {
            __setstate__(state) {
                [this.weight, this.bias, this.stride, this.padding, this.output_padding, this.dilation, this.groups, this.output_min, this.output_max] = state;
            }
        });
        this.registerType('__torch__.torch.classes.tensorrt.Engine', class {
            __setstate__(state) {
                [this.abi_target, this.name, this.device, this.engine, this.input_binding_names, this.output_binding_names, this.hw_compatible, this.serialized_metadata, this.target_platform] = state;
            }
        });
        // Schemas for the well-known TorchBind classes above so that attribute
        // types and method signatures resolve while loading modules.
        const custom_classes = [
            { name: '__torch__.torch.classes._nnapi.Compilation', methods: [
                '__init__(__torch__.torch.classes._nnapi.Compilation self) -> NoneType',
                'init(__torch__.torch.classes._nnapi.Compilation self, Tensor serialized_model_tensor, Tensor[] parameter_buffers) -> NoneType',
                'init2(__torch__.torch.classes._nnapi.Compilation self, Tensor serialized_model_tensor, Tensor[] parameter_buffers, int compilation_preference, bool relax_f32_to_f16) -> NoneType',
                'run(__torch__.torch.classes._nnapi.Compilation self, Tensor[] inputs, Tensor[] outputs) -> NoneType'
            ] },
            { name: '__torch__.torch.classes.quantized.Conv2dPackedParamsBase', attributes: 'Tensor weight, Tensor bias, int[] stride, int[] padding, int[] dilation, int groups', methods: ['unpack(__torch__.torch.classes.quantized.Conv2dPackedParamsBase self) -> ((Tensor, Tensor?))'] },
            { name: '__torch__.torch.classes.quantized.Conv3dPackedParamsBase', attributes: 'Tensor weight, Tensor bias, int[] stride, int[] padding, int[] dilation, int groups', methods: ['unpack(__torch__.torch.classes.quantized.Conv3dPackedParamsBase self) -> ((Tensor, Tensor?))'] },
            { name: '__torch__.torch.classes.quantized.LinearPackedParamsBase', attributes: 'Tensor weight, Tensor? bias' },
            { name: '__torch__.torch.classes.quantized.EmbeddingPackedParamsBase', attributes: 'int version, Tensor[] tensors, float[] doubles, int[] longs', methods: [] },
            { name: '__torch__.torch.classes.rnn.CellParamsBase', attributes: 'str type, Tensor[] tensors, float[] doubles, int[] longs, __torch__.torch.classes.quantized.LinearPackedParamsBase[] packed_params' },
            { name: '__torch__.torch.classes.xnnpack.Conv2dOpContext', attributes: 'Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups, int[] output_min, int[] output_max' },
            { name: '__torch__.torch.classes.xnnpack.LinearOpContext', attributes: 'Tensor weight, Tensor bias, int[] output_min, int[] output_max' },
            { name: '__torch__.torch.classes.xnnpack.TransposeConv2dOpContext', attributes: 'Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] output_padding, int[] dilation, int groups, int[] output_min, int[] output_max' },
            { name: '__torch__.torch.classes.tensorrt.Engine' }
        ];
        for (const known_type of custom_classes) {
            const prefix = new torch._C.QualifiedName(known_type.name);
            const type = torch.ClassType.create(known_type.name, this._compilation_unit, false);
            for (const known_method of known_type.methods || []) {
                const schema = new torch.FunctionSchema(known_method);
                const name = new torch._C.QualifiedName(prefix, schema.name);
                const fn = new torch._C.BuiltinOpFunction(name, schema);
                type.addMethod(fn);
            }
            if (known_type.attributes) {
                // Parse the attribute list by wrapping it in a dummy schema.
                const schema = new torch.FunctionSchema(`(${known_type.attributes}) -> ()`);
                for (const arg of schema.arguments) {
                    type.addAttribute(arg.name, arg.real_type);
                }
            }
            torch._C.registerCustomClass(type);
        }
    }

    // Resolves 'torch.<op>(...)' and 'ops.<module>.<op>(...)' calls against the
    // registered operator tables before falling back to generic Python calls.
    call(target, name, args, keywords, context) {
        const ast = this.ast;
        const torch = this.torch;
        if (target instanceof ast.Name && target.id === 'torch') {
            const fn = torch.ops.aten.__getattr__(name);
            if (fn) {
                const evalArgs = args.map((arg) => this.expression(arg, context));
                return fn.__call__(...evalArgs);
            }
        }
        if (target instanceof ast.Attribute && target.value instanceof ast.Name && target.value.id === 'ops') {
            const module = torch.ops[target.attr];
            if (!module) {
                throw new pytorch.Error(`Unknown torch.ops module '${target.attr}'.`);
            }
            const fn = module.__getattr__(name);
            if (fn) {
                const evalArgs = args.map((arg) => this.expression(arg, context));
                return fn.__call__(...evalArgs);
            }
        }
        return super.call(target, name, args, keywords, context);
    }

    // Instantiates direct Enum subclasses by hand (value assigned, no __init__);
    // everything else defers to the base implementation.
    invoke(target, args) {
        if (target && Array.isArray(target.__bases__) && target.__bases__.length > 0 && target.__bases__[0] === this.enum.Enum) {
            const instance = new target();
            instance.value = args;
            return instance;
        }
        return super.invoke(target, args);
    }

    // Maps the bare base-class name 'Enum' to the builtin Enum type.
    base(expr, context) {
        const ast = this.ast;
        if (expr instanceof ast.Name) {
            switch (expr.id) {
                case 'Enum': return this.enum.Enum;
                default: break;
            }
        }
        return this.expression(expr, context);
    }
};
  1812. pytorch.Reader.Package = class extends pytorch.Reader {
  1813. constructor(entries) {
  1814. super();
  1815. this.type = 'pytorch.package';
  1816. this.entries = entries;
  1817. }
  1818. async read(metadata) {
  1819. this.execution = new pytorch.Execution(null, metadata);
  1820. for (const event of this._events) {
  1821. this.execution.on(event[0], event[1]);
  1822. }
  1823. const torch = this.execution.__import__('torch');
  1824. const reader = new torch.PyTorchFileReader(this.entries);
  1825. const version = reader.version();
  1826. this.format = pytorch.Utility.format('PyTorch Package', version);
  1827. this.modules = new Map();
  1828. const records = reader.get_all_records().filter((name) => {
  1829. if (!name.startsWith('.data/') && !name.endsWith('.py')) {
  1830. const stream = reader.get_record(name);
  1831. if (stream && stream.length > 2) {
  1832. const signature = stream.peek(2);
  1833. if (signature[0] === 0x80 && signature[1] < 7) {
  1834. return true;
  1835. }
  1836. }
  1837. }
  1838. return false;
  1839. });
  1840. const entries = records.map((name) => {
  1841. const parts = name.split('/');
  1842. const resource = parts.pop();
  1843. const module = parts.join('.');
  1844. return [module, resource];
  1845. });
  1846. if (entries.length > 0) {
  1847. for (const name of reader.get_all_records()) {
  1848. if (!name.startsWith('.data/') && name.endsWith('.py')) {
  1849. const stream = reader.get_record(name);
  1850. const buffer = stream.peek();
  1851. this.execution.add(name, buffer);
  1852. }
  1853. }
  1854. metadata.register(this.execution);
  1855. const importer = new torch.package.PackageImporter(reader);
  1856. for (const entry of entries) {
  1857. const module = importer.load_pickle(entry[0], entry[1]);
  1858. const key = `${entry[0].replace(/\./, '/')}/${entry[1]}`;
  1859. this.modules.set(key, module);
  1860. }
  1861. }
  1862. delete this.entries;
  1863. }
  1864. };
// Tensor memory-format tags; values mirror c10::MemoryFormat in PyTorch
// (c10/core/MemoryFormat.h).
pytorch.MemoryFormat = {
    Contiguous: 0,
    Preserve: 1,
    ChannelsLast: 2,
    ChannelsLast3d: 3
};
// Tensor storage-layout tags; values mirror the first entries of c10::Layout
// (c10/core/Layout.h).
pytorch.Layout = {
    Strided: 0,
    Sparse: 1,
    Mkldnn: 2
};
  1876. pytorch.Utility = class {
  1877. static isTensor(obj) {
  1878. const name = obj && obj.__class__ ? obj.__class__.__module__ : null;
  1879. switch (name) {
  1880. case 'torch':
  1881. case 'torch.cuda':
  1882. return obj.__class__.__name__.endsWith('Tensor');
  1883. case 'torch.nn.parameter':
  1884. return obj.__class__.__name__ === 'Parameter';
  1885. default:
  1886. return false;
  1887. }
  1888. }
  1889. static toTensor(obj) {
  1890. const name = obj && obj.__class__ ? obj.__class__.__module__ : null;
  1891. switch (name) {
  1892. case 'torch':
  1893. case 'torch.cuda':
  1894. return obj.__class__.__name__.endsWith('Tensor') ? obj : null;
  1895. case 'torch.nn.parameter':
  1896. if (obj.__class__.__name__ === 'Parameter') {
  1897. const data = obj.data;
  1898. if (typeof obj.__name__ === 'string') {
  1899. data.__name__ = obj.__name__;
  1900. }
  1901. return data;
  1902. }
  1903. return null;
  1904. default:
  1905. return null;
  1906. }
  1907. }
  1908. static toType(type) {
  1909. switch (type.kind()) {
  1910. case 'OptionalType': return `${pytorch.Utility.toType(type.getElementType())}?`;
  1911. case 'ListType': return `${pytorch.Utility.toType(type.getElementType())}[]`;
  1912. case 'BoolType': return 'boolean';
  1913. case 'IntType': return 'int64';
  1914. case 'FloatType': return 'float32';
  1915. case 'StringType': return 'string';
  1916. case 'ComplexType': return 'complex';
  1917. case 'NumberType': return 'scalar';
  1918. case 'TensorType': return 'tensor';
  1919. case 'TupleType': return `tuple<${type.elements().map((type) => pytorch.Utility.toType(type)).join(', ')}>`;
  1920. case 'DictType': return `map<${pytorch.Utility.toType(type.getKeyType())}, ${pytorch.Utility.toType(type.getValueType())}>`;
  1921. case 'DeviceObjType': return 'device';
  1922. case 'SymIntType': return 'SymInt';
  1923. case 'ScalarTypeType': return 'ScalarType';
  1924. case 'MemoryFormat': return 'MemoryFormat';
  1925. case 'Layout': return 'Layout';
  1926. case 'VarType': return type.annotation_str;
  1927. case 'NoneType': return 'None';
  1928. case 'AnyType': return 'object';
  1929. case 'AnyListType': return 'list';
  1930. case 'AnyTupleType': return 'tuple';
  1931. case 'ClassType': return type.annotation_str;
  1932. case 'EnumType': return type.annotation_str;
  1933. default: throw new pytorch.Error(`Unsupported type '${type.kind()}'.`);
  1934. }
  1935. }
  1936. static toString(ivalue) {
  1937. if (ivalue.isInt()) {
  1938. return ivalue.toInt();
  1939. }
  1940. if (ivalue.isDouble()) {
  1941. return ivalue.toDouble();
  1942. }
  1943. if (ivalue.isEnum()) {
  1944. return ivalue.toEnumHolder().name();
  1945. }
  1946. if (ivalue.isList()) {
  1947. return ivalue.toList().map((item) => pytorch.Utility.toString(item));
  1948. }
  1949. throw new pytorch.Error(`Unsupported IValue '${ivalue.tag}.`);
  1950. }
  1951. static constant(node, name) {
  1952. const kind = node.kindOf(name);
  1953. switch (kind) {
  1954. case 's': return node.s(name);
  1955. case 'i': return node.i(name);
  1956. case 'f': return node.f(name);
  1957. case 'ss': return node.ss(name);
  1958. case 'ival': return node.ival(name);
  1959. default: throw new pytorch.Error(`Unsupported attribute kind '${kind}'.`);
  1960. }
  1961. }
  1962. static unique(value) {
  1963. return value.hasDebugName() ? `%${value.debugName().toString()}` : `%${value.unique().toString()}`;
  1964. }
  1965. static isObject(obj) {
  1966. const type = obj && obj.__class__ && obj.__class__.__module__ && obj.__class__.__name__ ? `${obj.__class__.__module__}.${obj.__class__.__name__}` : null;
  1967. switch (type) {
  1968. case '__torch__.torch.classes.xnnpack.LinearOpContext':
  1969. case '__torch__.torch.classes.xnnpack.Conv2dOpContext':
  1970. case '__torch__.torch.classes.xnnpack.TransposeConv2dOpContext':
  1971. case '__torch__.torch.classes.rnn.CellParamsBase':
  1972. case '__torch__.torch.classes.rnn.CellParamsBase[]':
  1973. case '__torch__.torch.classes.quantized.LinearPackedParamsBase':
  1974. case '__torch__.torch.classes.quantized.Conv2dPackedParamsBase':
  1975. case '__torch__.torch.classes.quantized.Conv3dPackedParamsBase':
  1976. case '__torch__.torch.classes.quantized.EmbeddingPackedParamsBase':
  1977. return true;
  1978. default:
  1979. return false;
  1980. }
  1981. }
  1982. static isSubclass(value, name) {
  1983. if (value && value.__module__ && value.__name__) {
  1984. return name === `${value.__module__}.${value.__name__}`;
  1985. } else if (value && value.__bases__) {
  1986. return value.__bases__.some((obj) => pytorch.Utility.isSubclass(obj, name));
  1987. }
  1988. return false;
  1989. }
  1990. static isInstance(value, name) {
  1991. return value && value.__class__ ? pytorch.Utility.isSubclass(value.__class__, name) : false;
  1992. }
  1993. static format(name, value) {
  1994. // https://github.com/pytorch/pytorch/blob/master/caffe2/serialize/inline_container.h
  1995. // kProducedFileFormatVersion
  1996. const versions = new Map([
  1997. ['1', 'v1.3'],
  1998. ['2', 'v1.5'], // 7a2889b014ce36fcc333b2c6de6f29f976652f84 (#28122)
  1999. ['3', 'v1.6'], // 2ec6a30722b0ef85632a2f3e7ce6f80da403008a (#36085)
  2000. ['4', 'v1.6'], // 95489b590f00801bdee7f41783f30874883cf6bb (#38620)
  2001. ['5', 'v1.7'], // cb26661fe4faf26386703180a9045e6ac6d157df (#40364)
  2002. ['6', 'v1.9'], // 3ee7637ffa50df0d9b231c7b40778ac1c390bf4a (#59714)
  2003. ['7', 'v1.10'], // 880098a7e34a20628f960daa8eab0eb1ad566c39 (#63651)
  2004. ['8', 'v1.11'], // b28e696516a7f0c7a6ead6da967590ce6c1d6698 (#71486)
  2005. ['9', 'v1.11'], // 8757e21c6a4fc00e83539aa7f9c28eb11eff53c1 (#72051)
  2006. ['10', 'v1.12'] // 4f8b986e28736b59bc46cd0873a0f36fdaa6f5b8 (#61439)
  2007. ]);
  2008. value = value.toString();
  2009. if (!versions.has(value)) {
  2010. throw new pytorch.Error(`Unsupported '${name}' version '${value}'.`);
  2011. }
  2012. return `${name} ${versions.get(value)}`;
  2013. }
static weights(obj) {
    // Best-effort extraction of a "state dict" from an arbitrary deserialized
    // checkpoint object. Returns a Map of module name -> plain object of
    // named tensors/scalars, or null when 'obj' does not look like a weights
    // container.
    let type = obj && obj.__class__ && obj.__class__.__module__ && obj.__class__.__name__ ? `${obj.__class__.__module__}.${obj.__class__.__name__}` : null;
    if (type === 'torch.jit._script.RecursiveScriptModule') {
        // Unwrap a RecursiveScriptModule into a plain { attribute: value } object.
        type = obj._c._type();
        const target = {};
        for (let i = 0; i < type.numAttributes(); i++) {
            const k = type.getAttributeName(i);
            target[k] = obj.__getattr__(k);
        }
        type = obj._c.qualified_name;
        obj = target;
    } else if (type && type !== 'builtins.dict' && type !== 'builtins.object' && type !== 'collections.OrderedDict' && type !== 'torch.nn.modules.module.Module' && type !== '__torch__.Module') {
        // Any other typed object is not a recognized weights container.
        return null;
    }
    if (pytorch.Utility.isTensor(obj)) {
        return null;
    }
    // A plain object whose keys are mostly dotted (or '|'-separated) tensor
    // paths is treated like a flat state dict; convert to a Map so the branch
    // below can handle it uniformly.
    if (obj instanceof Map === false && obj && !Array.isArray(obj) && Object(obj) === obj) {
        const entries = Object.entries(obj);
        const named = entries.filter(([name, value]) => (typeof name === 'string' && (name.indexOf('.') !== -1 || name.indexOf('|') !== -1)) && pytorch.Utility.isTensor(value));
        if (named.length > 0 && (named.length / entries.length) >= 0.8) {
            obj = new Map(entries);
        }
    }
    if (obj instanceof Map) {
        // Flat state dict: keys are 'module.sub.weight'-style paths. Require
        // at least 80% of entries to look like paths, no dicts of tensors as
        // values, and not all values to be arrays.
        const entries = Array.from(obj).filter(([name]) => name !== '_metadata');
        const names = entries.filter(([name]) => typeof name === 'string' && (name.indexOf('.') !== -1 || name.indexOf('|') !== -1));
        if (names.length > 1 && (names.length / entries.length) >= 0.8 &&
            (entries.every(([, value]) => !pytorch.Utility.isInstance(value, 'builtins.dict') || Array.from(value.values()).every((value) => !pytorch.Utility.isTensor(value)))) &&
            (!entries.every(([, value]) => Array.isArray(value)))) {
            // Group entries by their parent path; the last path component
            // becomes the property name ('weight', 'bias', ...).
            const modules = new Map();
            for (const [name, value] of entries) {
                const separator = name.indexOf('.') === -1 && name.indexOf('|') !== -1 ? '|' : '.';
                const path = name.split(separator);
                let property = path.pop();
                if (path.length > 1 && path[path.length - 1] === '_packed_params') {
                    // Keep quantized '_packed_params.<name>' attached to the parent module.
                    property = `${path.pop()}.${property}`;
                }
                const key = path.join(separator);
                if (!modules.has(key)) {
                    modules.set(key, {});
                }
                const module = modules.get(key);
                if (pytorch.Utility.isTensor(value)) {
                    value.__name__ = name;
                }
                module[property] = value;
            }
            return modules;
        }
    }
    // Nested container: { module_name: { tensor_name: tensor, ... }, ... }.
    if (obj && !Array.isArray(obj) && Object(obj) === obj) {
        const modules = new Map();
        const entries = obj instanceof Map ? Array.from(obj) : Object.entries(obj);
        if (entries.length > 0 && entries) { // NOTE(review): '&& entries' is always truthy and redundant.
            for (const [key, value] of entries) {
                const name = key.toString();
                // Each value must be a non-tensor container without submodules.
                if (!value || Object(value) !== value || pytorch.Utility.isTensor(value) || ArrayBuffer.isView(value) || value._modules instanceof Map) {
                    return null;
                }
                if (!modules.has(name)) {
                    modules.set(name, {});
                }
                const module = modules.get(name);
                let tensor = false;
                const entries = value instanceof Map ? value : new Map(Object.entries(value));
                for (const [name, value] of entries) {
                    if (typeof name !== 'string') {
                        return null;
                    }
                    if (name.indexOf('.') !== -1) {
                        // Dotted keys inside a nested module mean this is not a weights map.
                        return null;
                    }
                    if (name === '_metadata') {
                        continue;
                    }
                    if (typeof value === 'string' || typeof value === 'number') {
                        module[name] = value;
                        continue;
                    }
                    if (pytorch.Utility.isTensor(value)) {
                        value.__name__ = name;
                        module[name] = value;
                        tensor = true;
                    }
                }
                // Each module must contain at least one tensor to count as weights.
                if (!tensor) {
                    return null;
                }
            }
            return modules;
        }
    }
    return null;
}
  2109. static isMetadataObject(obj) {
  2110. if (pytorch.Utility.isInstance(obj, 'collections.OrderedDict')) {
  2111. for (const value of obj.values()) {
  2112. if (pytorch.Utility.isInstance(value, 'builtins.dict')) {
  2113. const entries = Array.from(value);
  2114. if (entries.length !== 1 && entries[0] !== 'version' && entries[1] !== 1) {
  2115. return false;
  2116. }
  2117. }
  2118. }
  2119. return true;
  2120. }
  2121. return false;
  2122. }
  2123. };
  2124. nnapi.SerializedModel = class {
  2125. constructor(serialized_model, buffers) {
  2126. const reader = base.BinaryReader.open(serialized_model);
  2127. this.version = reader.int32();
  2128. if (this.version !== 1) {
  2129. throw new pytorch.Error('Invalid NNAPI serialized model version.');
  2130. }
  2131. const operands = new Array(reader.int32());
  2132. const values = new Array(reader.int32());
  2133. this.operations = new Array(reader.int32());
  2134. this.inputs = new Array(reader.int32());
  2135. this.outputs = new Array(reader.int32());
  2136. const data_types = new Map([
  2137. [0, 'float32'],
  2138. [1, 'int32'],
  2139. [2, 'uint32'],
  2140. [3, 'float32[]'],
  2141. [4, 'int32[]'],
  2142. [5, 'quant8_asymm[]'],
  2143. [6, 'boolean'],
  2144. [7, 'quant16_symm[]'],
  2145. [8, 'float16[]'],
  2146. [9, 'boolean[]'],
  2147. [10, 'float16'],
  2148. [11, 'quant8_symm_per_channel[]'],
  2149. [12, 'quant16_asymm[]'],
  2150. [13, 'quant8_symm[]'],
  2151. [14, 'quant8_asymm_signed[]'],
  2152. [16, 'model']
  2153. ]);
  2154. for (let i = 0; i < operands.length; i++) {
  2155. const data_type = reader.int32();
  2156. operands[i] = {
  2157. index: i,
  2158. data_type: data_types.has(data_type) ? data_types.get(data_type) : data_type,
  2159. dimensions: new Array(reader.uint32()),
  2160. scale: reader.float32(),
  2161. zero_point: reader.int32()
  2162. };
  2163. }
  2164. for (let i = 0; i < values.length; i++) {
  2165. values[i] = {
  2166. index: reader.int32(),
  2167. source_type: reader.int32(),
  2168. source_length: reader.uint32()
  2169. };
  2170. }
  2171. for (let i = 0; i < this.operations.length; i++) {
  2172. this.operations[i] = {
  2173. index: reader.int32(),
  2174. identifier: i,
  2175. inputs: new Array(reader.uint32()),
  2176. outputs: new Array(reader.uint32())
  2177. };
  2178. }
  2179. for (const operand of operands) {
  2180. for (let i = 0; i < operand.dimensions.length; i++) {
  2181. operand.dimensions[i] = reader.uint32();
  2182. }
  2183. }
  2184. for (const value of values) {
  2185. const index = value.index;
  2186. const operand = operands[index];
  2187. switch (value.source_type) {
  2188. case 0: { // immediate
  2189. switch (operand.data_type) {
  2190. case 'boolean':
  2191. operand.value = reader.byte() ? true : false;
  2192. reader.skip(3);
  2193. break;
  2194. case 'int32':
  2195. operand.value = reader.int32();
  2196. break;
  2197. case 'float32':
  2198. operand.value = reader.float32();
  2199. break;
  2200. case 'int32[]':
  2201. operand.data = reader.read(value.source_length);
  2202. break;
  2203. case 'float32[]':
  2204. operand.data = reader.read(value.source_length);
  2205. break;
  2206. default:
  2207. throw new pytorch.Error(`Unsupported NNAPI operand type '${operand.data_type}'.`);
  2208. }
  2209. break;
  2210. }
  2211. case 2: { // numbered buffer
  2212. if (value.source_length !== 12) {
  2213. throw new pytorch.Error('Invalid NNAPI numbered buffer source length.');
  2214. }
  2215. const number = reader.uint32();
  2216. const offset = reader.uint32();
  2217. const operand_length = reader.uint32();
  2218. if (number < buffers.length && buffers[number].data) {
  2219. const storage = buffers[number];
  2220. const data = storage.data && storage.data.peek ? storage.data.peek() : storage.data;
  2221. operand.data = data.slice(offset, operand_length);
  2222. }
  2223. break;
  2224. }
  2225. case 3: { // numbered memory
  2226. throw new pytorch.Error('NNAPI numbered memory buffer not implemented.');
  2227. }
  2228. default: {
  2229. throw new pytorch.Error('Unsupported NNAPI value source type.');
  2230. }
  2231. }
  2232. }
  2233. for (const operation of this.operations) {
  2234. for (let i = 0; i < operation.inputs.length; i++) {
  2235. const index = reader.uint32();
  2236. operation.inputs[i] = operands[index];
  2237. }
  2238. for (let i = 0; i < operation.outputs.length; i++) {
  2239. const index = reader.uint32();
  2240. operation.outputs[i] = operands[index];
  2241. }
  2242. }
  2243. for (let i = 0; i < this.inputs.length; i++) {
  2244. const index = reader.uint32();
  2245. this.inputs[i] = operands[index];
  2246. }
  2247. for (let i = 0; i < this.outputs.length; i++) {
  2248. const index = reader.uint32();
  2249. this.outputs[i] = operands[index];
  2250. }
  2251. if (reader.position !== reader.length) {
  2252. throw new pytorch.Error('Invalid NNAPI serialized model length.');
  2253. }
  2254. }
  2255. };
  2256. nnapi.Graph = class {
  2257. constructor(model) {
  2258. this.name = 'torch.classes._nnapi.Compilation';
  2259. this.nodes = [];
  2260. this.inputs = [];
  2261. this.outputs = [];
  2262. const values = new Map();
  2263. values.map = (operand) => {
  2264. if (!values.has(operand.index)) {
  2265. const name = operand.index.toString();
  2266. const dimensions = operand.dimensions;
  2267. const shape = new pytorch.TensorShape(dimensions);
  2268. let dataType = operand.data_type.replace('[]', '');
  2269. let quantization = null;
  2270. switch (dataType) {
  2271. case 'quant8_asymm':
  2272. case 'quant8_symm_per_channel':
  2273. case 'quant8_symm':
  2274. case 'quant8_asymm_signed[]':
  2275. case 'quant16_asymm':
  2276. case 'quant16_symm':
  2277. quantization = dataType;
  2278. dataType = dataType.indexOf('16') === -1 ? 'uint8' : 'uint16';
  2279. break;
  2280. default:
  2281. break;
  2282. }
  2283. const type = new pytorch.TensorType(dataType, shape);
  2284. let initializer = null;
  2285. if (operand.data) {
  2286. const size = dimensions.reduce((a, b) => a * b, 1);
  2287. const tensor = {
  2288. dtype: { __reduce__: () => dataType },
  2289. size: () => dimensions,
  2290. stride: () => null,
  2291. storage_offset: () => 0,
  2292. storage: () => ({
  2293. dtype: { __reduce__: () => type.dataType },
  2294. data: operand.data, size: () => size
  2295. })
  2296. };
  2297. const context = new pytorch.Context();
  2298. initializer = new pytorch.Tensor(context, null, tensor);
  2299. }
  2300. if (quantization || (operand.scale !== undefined && operand.scale !== 0) || (operand.zero_point !== undefined && operand.zero_point !== 0)) {
  2301. quantization = {
  2302. type: quantization || 'linear',
  2303. scale: [operand.scale],
  2304. offset: [operand.zero_point]
  2305. };
  2306. }
  2307. const value = new pytorch.Value(name, type, quantization, initializer);
  2308. values.set(operand.index, value);
  2309. }
  2310. return values.get(operand.index);
  2311. };
  2312. const metadata = new nnapi.Metadata();
  2313. for (const operation of model.operations) {
  2314. const node = new nnapi.Node(metadata, operation, values);
  2315. this.nodes.push(node);
  2316. }
  2317. for (let i = 0; i < model.inputs.length; i++) {
  2318. const name = i.toString();
  2319. const operand = model.inputs[i];
  2320. const argument = new pytorch.Argument(name, [values.map(operand)]);
  2321. this.inputs.push(argument);
  2322. }
  2323. for (let i = 0; i < model.outputs.length; i++) {
  2324. const name = i.toString();
  2325. const operand = model.outputs[i];
  2326. const argument = new pytorch.Argument(name, [values.map(operand)]);
  2327. this.outputs.push(argument);
  2328. }
  2329. }
  2330. };
  2331. nnapi.Node = class {
  2332. constructor(metadata, operation, values) {
  2333. const signature = (operation.inputs || []).map((input) => input.data_type);
  2334. this.name = '';
  2335. this.type = metadata.type(operation.index, signature);
  2336. this.inputs = [];
  2337. this.outputs = [];
  2338. this.attributes = [];
  2339. this.chain = [];
  2340. if (operation.identifier !== undefined) {
  2341. this.identifier = operation.identifier.toString();
  2342. }
  2343. if (Array.isArray(operation.inputs)) {
  2344. const inputs = this.type.inputs;
  2345. for (let i = 0; i < operation.inputs.length; i++) {
  2346. const name = i < inputs.length ? inputs[i].name : i.toString();
  2347. const operand = operation.inputs[i];
  2348. if (operand.dimensions.length > 0) {
  2349. const value = values.map(operand);
  2350. const argument = new pytorch.Argument(name, [value]);
  2351. this.inputs.push(argument);
  2352. } else if (name === 'activation') {
  2353. const activation = new Map([[1, 19], [2, 20], [3, 21]]).get(operand.value) || 0;
  2354. if (activation !== 0) {
  2355. this.chain.push(new nnapi.Node(metadata, { index: activation }));
  2356. }
  2357. } else {
  2358. const attribute = new pytorch.Argument(name, operand.value, operand.data_type, false);
  2359. this.inputs.push(attribute);
  2360. }
  2361. }
  2362. }
  2363. if (Array.isArray(operation.outputs)) {
  2364. const outputs = this.type.outputs;
  2365. for (let i = 0; i < operation.outputs.length; i++) {
  2366. const name = i < outputs.length ? outputs[i].name : i.toString();
  2367. const operand = operation.outputs[i];
  2368. const value = values.map(operand);
  2369. const argument = new pytorch.Argument(name, [value]);
  2370. this.outputs.push(argument);
  2371. }
  2372. }
  2373. }
  2374. };
  2375. nnapi.Metadata = class {
  2376. constructor() {
  2377. this._types = new Map();
  2378. // https://developer.android.com/ndk/reference/group/neural-networks
  2379. // https://github.com/pytorch/pytorch/commits/master/torch/backends/_nnapi/serializer.py
  2380. this.register(0, 'ADD', '', ['A', 'B'], [['activation', 'int32']], ['C']);
  2381. this.register(1, 'AVERAGE_POOL_2D', 'Pool', ['input'], [['padding_left', 'int32'], ['padding_right', 'int32'], ['padding_top', 'int32'], ['padding_bottom', 'int32'], ['stride_x', 'int32'], ['stride_y', 'int32'], ['filter_x', 'int32'], ['filter_y', 'int32'], ['activation', 'int32'], ['nchw', 'boolean']], ['output']);
  2382. this.register(1, 'AVERAGE_POOL_2D', 'Pool', ['input'], [['padding_scheme', 'int32'], ['stride_x', 'int32'], ['stride_y', 'int32'], ['filter_x', 'int32'], ['filter_y', 'int32'], ['activation', 'int32'], ['nchw', 'boolean']], ['output']);
  2383. this.register(2, 'CONCATENATION');
  2384. this.register(3, 'CONV_2D', 'Layer', ['input', 'weights', 'bias'], [['padding_left', 'int32'], ['padding_right', 'int32'], ['padding_top', 'int32'], ['padding_bottom', 'int32'], ['stride_x', 'int32'], ['stride_y', 'int32'], ['activation', 'int32'], ['nchw', 'boolean'], ['dilation_width', 'int32'], ['dilation_height', 'int32']], ['output']);
  2385. this.register(3, 'CONV_2D', 'Layer', ['input', 'weights', 'bias'], [['padding_scheme', 'int32'], ['stride_x', 'int32'], ['stride_y', 'int32'], ['activation', 'int32'], ['nchw', 'boolean'], ['dilation_width', 'int32'], ['dilation_height', 'int32']], ['output']);
  2386. this.register(4, 'DEPTHWISE_CONV_2D', 'Layer', ['input', 'weights', 'bias'], [['padding_left', 'int32'], ['padding_right', 'int32'], ['padding_top', 'int32'], ['padding_bottom', 'int32'], ['stride_x', 'int32'], ['stride_y', 'int32'], ['activation', 'int32'], ['nchw', 'boolean'], ['dilation_width', 'int32'], ['dilation_height', 'int32']], ['output']);
  2387. this.register(4, 'DEPTHWISE_CONV_2D', 'Layer', ['input', 'weights', 'bias'], [['padding_scheme', 'int32'], ['stride_x', 'int32'], ['stride_y', 'int32'], ['activation', 'int32'], ['nchw', 'boolean'], ['dilation_width', 'int32'], ['dilation_height', 'int32']], ['output']);
  2388. this.register(5, 'DEPTH_TO_SPACE');
  2389. this.register(6, 'DEQUANTIZE');
  2390. this.register(7, 'EMBEDDING_LOOKUP');
  2391. this.register(8, 'FLOOR');
  2392. this.register(9, 'FULLY_CONNECTED', 'Layer', ['input', 'weights', 'bias'], [['activation', 'int32']], ['output']);
  2393. this.register(10, 'HASHTABLE_LOOKUP');
  2394. this.register(11, 'L2_NORMALIZATION');
  2395. this.register(12, 'L2_POOL_2D', 'Pool');
  2396. this.register(13, 'LOCAL_RESPONSE_NORMALIZATION');
  2397. this.register(14, 'LOGISTIC');
  2398. this.register(15, 'LSH_PROJECTION');
  2399. this.register(16, 'LSTM', 'Layer');
  2400. this.register(17, 'MAX_POOL_2D', 'Pool');
  2401. this.register(18, 'MUL');
  2402. this.register(19, 'RELU', 'Activation', ['input'], [], ['output']);
  2403. this.register(20, 'RELU1', 'Activation');
  2404. this.register(21, 'RELU6', 'Activation');
  2405. this.register(22, 'RESHAPE', 'Shape', ['input', 'shape'], [], ['output']);
  2406. this.register(23, 'RESIZE_BILINEAR');
  2407. this.register(24, 'RNN', 'Layer');
  2408. this.register(25, 'SOFTMAX', 'Activation');
  2409. this.register(26, 'SPACE_TO_DEPTH');
  2410. this.register(27, 'SVDF');
  2411. this.register(28, 'TANH');
  2412. this.register(29, 'BATCH_TO_SPACE_ND');
  2413. this.register(30, 'DIV');
  2414. this.register(31, 'MEAN');
  2415. this.register(32, 'PAD');
  2416. this.register(33, 'SPACE_TO_BATCH_ND');
  2417. this.register(34, 'SQUEEZE');
  2418. this.register(35, 'STRIDED_SLICE');
  2419. this.register(36, 'SUB');
  2420. this.register(37, 'TRANSPOSE');
  2421. this.register(38, 'ABS');
  2422. this.register(39, 'ARGMAX');
  2423. this.register(40, 'ARGMIN');
  2424. this.register(41, 'AXIS_ALIGNED_BBOX_TRANSFORM');
  2425. this.register(42, 'BIDIRECTIONAL_SEQUENCE_LSTM');
  2426. this.register(43, 'BIDIRECTIONAL_SEQUENCE_RNN');
  2427. this.register(44, 'BOX_WITH_NMS_LIMIT');
  2428. this.register(45, 'CAST');
  2429. this.register(46, 'CHANNEL_SHUFFLE');
  2430. this.register(47, 'DETECTION_POSTPROCESSING');
  2431. this.register(48, 'EQUAL');
  2432. this.register(49, 'EXP');
  2433. this.register(50, 'EXPAND_DIMS');
  2434. this.register(51, 'GATHER');
  2435. this.register(52, 'GENERATE_PROPOSALS');
  2436. this.register(53, 'GREATER');
  2437. this.register(54, 'GREATER_EQUAL');
  2438. this.register(55, 'GROUPED_CONV_2D');
  2439. this.register(56, 'HEATMAP_MAX_KEYPOINT');
  2440. this.register(57, 'INSTANCE_NORMALIZATION');
  2441. this.register(58, 'LESS');
  2442. this.register(59, 'LESS_EQUAL');
  2443. this.register(60, 'LOG');
  2444. this.register(61, 'LOGICAL_AND');
  2445. this.register(62, 'LOGICAL_NOT');
  2446. this.register(63, 'LOGICAL_OR');
  2447. this.register(64, 'LOG_SOFTMAX');
  2448. this.register(65, 'MAXIMUM');
  2449. this.register(66, 'MINIMUM');
  2450. this.register(67, 'NEG');
  2451. this.register(68, 'NOT_EQUAL');
  2452. this.register(69, 'PAD_V2');
  2453. this.register(70, 'POW');
  2454. this.register(71, 'PRELU');
  2455. this.register(72, 'QUANTIZE');
  2456. this.register(73, 'QUANTIZED_16BIT_LSTM');
  2457. this.register(74, 'RANDOM_MULTINOMIAL');
  2458. this.register(75, 'REDUCE_ALL');
  2459. this.register(76, 'REDUCE_ANY');
  2460. this.register(77, 'REDUCE_MAX');
  2461. this.register(78, 'REDUCE_MIN');
  2462. this.register(79, 'REDUCE_PROD');
  2463. this.register(80, 'REDUCE_SUM');
  2464. this.register(81, 'ROI_ALIGN');
  2465. this.register(82, 'ROI_POOLING');
  2466. this.register(83, 'RSQRT');
  2467. this.register(84, 'SELECT');
  2468. this.register(85, 'SIN');
  2469. this.register(86, 'SLICE');
  2470. this.register(87, 'SPLIT');
  2471. this.register(88, 'SQRT');
  2472. this.register(89, 'TILE');
  2473. this.register(90, 'TOPK_V2');
  2474. this.register(91, 'TRANSPOSE_CONV_2D', 'Layer');
  2475. this.register(92, 'UNIDIRECTIONAL_SEQUENCE_LSTM', 'Layer');
  2476. this.register(93, 'UNIDIRECTIONAL_SEQUENCE_RNN', 'Layer');
  2477. this.register(94, 'RESIZE_NEAREST_NEIGHBOR');
  2478. this.register(95, 'QUANTIZED_LSTM', 'Layer');
  2479. this.register(96, 'IF');
  2480. this.register(97, 'WHILE');
  2481. this.register(98, 'ELU', 'Activation');
  2482. this.register(99, 'HARD_SWISH', 'Activation');
  2483. this.register(100, 'FILL');
  2484. this.register(101, 'RANK');
  2485. }
  2486. register(index, name, category, inputs, attributes, outputs) {
  2487. inputs = inputs || [];
  2488. outputs = outputs || [];
  2489. attributes = attributes || [];
  2490. const type = {};
  2491. type.name = name;
  2492. type.inputs = inputs.map((name) => ({ name, type: 'Tensor' }));
  2493. type.inputs = type.inputs.concat(attributes.map(([name, type]) => ({ name, type })));
  2494. type.outputs = outputs.map((name) => ({ name, type: 'Tensor' }));
  2495. if (category) {
  2496. type.category = category;
  2497. }
  2498. if (!this._types.has(index)) {
  2499. this._types.set(index, []);
  2500. }
  2501. this._types.get(index).push(type);
  2502. }
  2503. type(index, signature) {
  2504. if (!this._types.has(index)) {
  2505. this._types.set(index, { name: index.toString(), inputs: [], outputs: [], attributes: [] });
  2506. }
  2507. const types = this._types.get(index);
  2508. for (const type of types) {
  2509. const inputs = type.inputs;
  2510. if (signature.length < inputs.length) {
  2511. if (inputs.every((input, i) => input.type === undefined || input.type === 'Tensor' || input.type === signature[i])) {
  2512. return type;
  2513. }
  2514. }
  2515. }
  2516. return types[0];
  2517. }
  2518. };
  2519. pytorch.Metadata = class {
  2520. static async open(context) {
  2521. if (!pytorch.Metadata._metadata) {
  2522. let data = null;
  2523. try {
  2524. data = await context.request('pytorch-metadata.json');
  2525. } catch {
  2526. // continue regardless of error
  2527. }
  2528. pytorch.Metadata._metadata = new pytorch.Metadata(data);
  2529. }
  2530. return pytorch.Metadata._metadata;
  2531. }
  2532. constructor(data) {
  2533. this._types = new Map();
  2534. this._attributes = new Map();
  2535. this._index = new Map();
  2536. if (data) {
  2537. const items = JSON.parse(data);
  2538. for (const item of items) {
  2539. const index = item.name.indexOf('(');
  2540. const key = index === -1 ? item.name : item.name.substring(0, index);
  2541. this._types.set(key, item);
  2542. }
  2543. }
  2544. }
  2545. add(name, value) {
  2546. this._types.set(name, value);
  2547. }
  2548. type(name) {
  2549. return this._types.get(name);
  2550. }
  2551. attribute(type, name) {
  2552. const key = `${type}:${name}`;
  2553. if (!this._attributes.has(key)) {
  2554. this._attributes.set(key, null);
  2555. const metadata = this.type(type);
  2556. if (metadata) {
  2557. if (metadata.inputs) {
  2558. for (const input of metadata.inputs) {
  2559. this._attributes.set(`${type}:${input.name}`, input);
  2560. }
  2561. }
  2562. if (metadata.attributes) {
  2563. for (const attribute of metadata.attributes) {
  2564. this._attributes.set(`${type}:${attribute.name}`, attribute);
  2565. }
  2566. }
  2567. }
  2568. }
  2569. return this._attributes.get(key);
  2570. }
  2571. register(execution) {
  2572. const torch = execution.register('torch');
  2573. const registry = torch._C.getRegistry();
  2574. const modules = new Set();
  2575. for (const [name, type] of this._types) {
  2576. if (name.indexOf('::') !== -1) {
  2577. const schema = torch.FunctionSchema.parse(type.name);
  2578. if (type.category) {
  2579. schema.category = type.category;
  2580. }
  2581. schema.setAliasAnalysis('FROM_SCHEMA');
  2582. const op = new torch._C.Operator(schema);
  2583. registry.registerOperator(op);
  2584. modules.add(type.name.split('::')[0]);
  2585. }
  2586. }
  2587. for (const module of modules) {
  2588. const namespace = new torch._ops._OpNamespace(module);
  2589. execution.register(`torch.ops.${module}`, namespace);
  2590. }
  2591. }
  2592. };
  2593. numpy.Tensor = class {
  2594. constructor(array) {
  2595. this.type = new numpy.TensorType(array.dtype.__name__, new numpy.TensorShape(array.shape));
  2596. this.stride = array.strides.map((stride) => stride / array.itemsize);
  2597. this.encoding = this.type.dataType === 'string' || this.type.dataType === 'object' ? '|' : array.dtype.byteorder;
  2598. this.values = this.type.dataType === 'string' || this.type.dataType === 'object' || this.type.dataType === 'void' ? array.flatten().tolist() : array.tobytes();
  2599. }
  2600. };
  2601. numpy.TensorType = class {
  2602. constructor(dataType, shape) {
  2603. this.dataType = dataType || '?';
  2604. this.shape = shape;
  2605. }
  2606. toString() {
  2607. return this.dataType + this.shape.toString();
  2608. }
  2609. };
  2610. numpy.TensorShape = class {
  2611. constructor(dimensions) {
  2612. this.dimensions = dimensions;
  2613. }
  2614. toString() {
  2615. return this.dimensions && this.dimensions.length > 0 ? `[${this.dimensions.join(',')}]` : '';
  2616. }
  2617. };
pytorch.Error = class extends Error {
    // Error type thrown throughout the PyTorch loader; the fixed 'name' acts
    // as the user-visible message prefix.
    constructor(message) {
        super(message);
        this.name = 'Error loading PyTorch model.';
    }
};
// Public module exports.
export const Metadata = pytorch.Metadata;
export const ModelFactory = pytorch.ModelFactory;