pytorch.js 148 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327
73278327932803281328232833284328532863287328832893290329132923293329432953296329732983299330033013302330333043305330633073308330933103311331233133314331533163317331833193320332133223323332433253326332733283329333033313332333333343335333633373338333933403341334233433344334533463347334833493350335133523353335433553356335733583359336033613362336333643365336633673368336933703371337233733374337533763377337833793380338133823383338433853386338733883389339033913392339333943395339633973398339934003401340234033404340534063407340834093410341134123413341434153416341734183419342034213422342334243425342634273428342934303431343234333434343534363437343834393440344134423443344434453446344734483449345034513452345334543455345634573458345934603461346234633464346534663467
  1. /* jshint esversion: 6 */
  2. // Experimental
  3. var pytorch = pytorch || {};
  4. var base = base || require('./base');
  5. pytorch.ModelFactory = class {
  6. match(context) {
  7. const identifier = context.identifier;
  8. const extension = identifier.split('.').pop().toLowerCase();
  9. if ([ 'pth', 'pt', 'pt1', 'pkl', 'bin', 'model', 'h5', 'pb', 't7', 'dms', 'ckpt', 'zip' ].indexOf(extension) !== -1 || identifier.toLowerCase().endsWith('.tar')) {
  10. if (pytorch.Container.open(context)) {
  11. return true;
  12. }
  13. }
  14. return false;
  15. }
  16. open(context, host) {
  17. const identifier = context.identifier;
  18. return host.require('./pickle').then((pickle) => {
  19. return host.require('./python').then((python) => {
  20. return pytorch.Metadata.open(host).then((metadata) => {
  21. try {
  22. const container = pytorch.Container.open(context, metadata, pickle, python, (error, fatal) => {
  23. const message = error && error.message ? error.message : error.toString();
  24. host.exception(new pytorch.Error(message.replace(/\.$/, '') + " in '" + identifier + "'."), fatal);
  25. });
  26. return new pytorch.Model(metadata, container);
  27. }
  28. catch (error) {
  29. host.exception(error, false);
  30. const message = error && error.message ? error.message : error.toString();
  31. throw new pytorch.Error(message.replace(/\.$/, '') + " in '" + identifier + "'.");
  32. }
  33. });
  34. });
  35. });
  36. }
  37. };
  38. pytorch.Model = class {
  39. constructor(metadata, container) {
  40. this._format = container.format;
  41. this._producer = container.producer || '';
  42. this._graphs = [ new pytorch.Graph(metadata, container) ];
  43. }
  44. get format() {
  45. return this._format;
  46. }
  47. get graphs() {
  48. return this._graphs;
  49. }
  50. };
  51. pytorch.Graph = class {
  52. constructor(metadata, container) {
  53. this._nodes = [];
  54. this._inputs = [];
  55. this._outputs = [];
  56. this._groups = true;
  57. this._littleEndian = container.littleEndian;
  58. if (container.format.startsWith('TorchScript ')) {
  59. this._name = container.name;
  60. const traced = container.trace();
  61. const initializers = new Map();
  62. if (container.data) {
  63. const queue = [ container.data ];
  64. while (queue.length > 0) {
  65. const module = queue.shift();
  66. for (const key of Object.keys(module)) {
  67. if (key !== '__module__' && key !== '__name__' && key !== '__parent__') {
  68. const obj = module[key];
  69. if (!Array.isArray(obj) && obj === Object(obj)) {
  70. if (pytorch.Utility.isTensor(obj)) {
  71. const parameter = obj;
  72. parameter.__parent__ = module;
  73. if (!parameter.initializer && parameter.storage) {
  74. parameter.initializer = new pytorch.Tensor(parameter.name, parameter, true);
  75. }
  76. if (parameter.__variable__ && parameter.__count__ === 1) {
  77. initializers.set(parameter.__variable__, parameter);
  78. }
  79. }
  80. else if (obj && obj.__module__ && obj.__name__) {
  81. obj.__parent__ = module;
  82. if (!obj.__id__) {
  83. obj.__id__ = key;
  84. }
  85. queue.push(obj);
  86. }
  87. }
  88. }
  89. }
  90. }
  91. }
  92. if (traced) {
  93. if (container.inputs) {
  94. for (const input of container.inputs) {
  95. this._inputs.push(new pytorch.Parameter(input, true, [
  96. new pytorch.Argument(input, null, null)
  97. ]));
  98. }
  99. }
  100. if (container.outputs) {
  101. for (const output of container.outputs) {
  102. this._outputs.push(new pytorch.Parameter(output, true, [
  103. new pytorch.Argument(output, null, null)
  104. ]));
  105. }
  106. }
  107. if (container.nodes) {
  108. for (const node of container.nodes) {
  109. const item = {
  110. type: node.type,
  111. node: node
  112. };
  113. this._nodes.push(new pytorch.Node(metadata, '', item, initializers));
  114. }
  115. }
  116. }
  117. if (container.data) {
  118. this._loadScriptModule(metadata, container, container.data, initializers);
  119. }
  120. if (container.constants) {
  121. const obj = {
  122. type: 'torch.nn.Constants',
  123. attributes: [],
  124. inputs: [],
  125. outputs: [],
  126. };
  127. let index = 0;
  128. for (const constant of container.constants) {
  129. if (constant.__variable__ && constant.__count__ > 1 && constant.storage) {
  130. const initializer = new pytorch.Tensor(constant.name, constant, true);
  131. obj.inputs.push(new pytorch.Parameter('c' + index.toString(), false, [
  132. new pytorch.Argument(constant.__variable__, initializer.type, initializer)
  133. ]));
  134. obj.outputs.push(new pytorch.Parameter('c' + index.toString(), false, [
  135. new pytorch.Argument(constant.__variable__)
  136. ]));
  137. }
  138. index++;
  139. }
  140. if (obj.inputs.length > 0) {
  141. this._nodes.push(new pytorch.Node(metadata, '', obj, null));
  142. }
  143. }
  144. }
  145. else if (container.data) {
  146. const data = container.data;
  147. this._type = (data.__module__ && data.__name__) ? (data.__module__ + '.' + data.__name__) : '';
  148. const input = 'data';
  149. this._inputs.push(new pytorch.Parameter(input, true, [ new pytorch.Argument(input, null, null) ]));
  150. const outputs = this._loadModule(metadata, container.data, [], [ input ]);
  151. for (const output of outputs) {
  152. this._outputs.push(new pytorch.Parameter(output, true, [ new pytorch.Argument(output, null, null) ]));
  153. }
  154. }
  155. else if (container.state) {
  156. for (const state_group of container.state) {
  157. const attributes = state_group.attributes || [];
  158. const inputs = state_group.states.map((parameter) => {
  159. return new pytorch.Parameter(parameter.name, true,
  160. parameter.arguments.map((state) => {
  161. const tensor = new pytorch.Tensor(state.id, state.value, this._littleEndian);
  162. return new pytorch.Argument(state.id, null, tensor);
  163. }));
  164. });
  165. const obj = {
  166. name: state_group.name,
  167. type: state_group.type || 'torch.nn.Module',
  168. attributes: attributes,
  169. inputs: inputs,
  170. outputs: []
  171. };
  172. this._nodes.push(new pytorch.Node(metadata, '', obj, null));
  173. }
  174. }
  175. }
  176. _loadModule(metadata, parent, groups, inputs) {
  177. if (parent.__module__ &&
  178. !parent.__module__ === 'torch.nn.modules.container' &&
  179. (!parent._modules || parent._modules.length == 0)) {
  180. this._createNode(groups, '', parent, inputs);
  181. return [];
  182. }
  183. if (!parent._modules) {
  184. throw new pytorch.Error('Module does not contain modules.');
  185. }
  186. for (const module of parent._modules) {
  187. const key = module[0];
  188. const value = module[1];
  189. if (module && value) {
  190. const type = value.__module__ + '.' + value.__name__;
  191. switch (type) {
  192. case 'torch.nn.modules.container.Sequential':
  193. groups.push(key);
  194. inputs = this._loadModule(metadata, value, groups, inputs);
  195. groups.pop(key);
  196. break;
  197. case 'torchvision.models.densenet._Transition':
  198. case 'torchvision.models.resnet.Bottleneck':
  199. case 'torchvision.models.densenet._DenseBlock':
  200. case 'torchvision.models.densenet._DenseLayer':
  201. case 'torchvision.models.inception.BasicConv2d':
  202. case 'torchvision.models.inception.InceptionAux':
  203. case 'torchvision.models.inception.InceptionA':
  204. case 'torchvision.models.inception.InceptionB':
  205. case 'torchvision.models.inception.InceptionC':
  206. case 'torchvision.models.inception.InceptionD':
  207. case 'torchvision.models.inception.InceptionE': {
  208. groups.push(key);
  209. const node = this._createNode(metadata, groups, key, value, inputs, this._littleEndian);
  210. inputs = [ node.name ];
  211. groups.pop(key);
  212. break;
  213. }
  214. default: {
  215. const node = this._createNode(metadata, groups, key, value, inputs);
  216. inputs = [ node.name ];
  217. break;
  218. }
  219. }
  220. }
  221. }
  222. return inputs;
  223. }
  224. _createNode(metadata, groups, key, obj, args) {
  225. const type = obj.__module__ + '.' + obj.__name__;
  226. const schema = metadata.type(type);
  227. let inputSchema = [ { name: 'input'} ];
  228. if (schema && schema.inputs && schema.inputs.length > 0) {
  229. inputSchema = schema.inputs.slice();
  230. }
  231. const inputs = [
  232. new pytorch.Parameter(inputSchema.shift().name, true, args.map((argument) => {
  233. return new pytorch.Argument(argument, null, null);
  234. }))
  235. ];
  236. const parameters = obj._parameters || obj._buffers || [];
  237. for (const parameter of parameters) {
  238. const key = parameter[0];
  239. const value = parameter[1];
  240. let visible = true;
  241. let inputName = '';
  242. if (inputSchema.length > 0) {
  243. const input = inputSchema.shift();
  244. inputName = input.name;
  245. visible = input.visible === false ? false : true;
  246. }
  247. if (parameter && value && (value.data || value.storage)) {
  248. let initializer = null;
  249. if (value.data) {
  250. initializer = new pytorch.Tensor('', value.data, this._littleEndian);
  251. }
  252. else if (value.storage) {
  253. initializer = new pytorch.Tensor('', value, this._littleEndian);
  254. }
  255. inputs.push(new pytorch.Parameter(inputName || key, visible, [ new pytorch.Argument('', null, initializer) ]));
  256. }
  257. }
  258. const group = groups.join('/');
  259. const name = group ? (group + '/' + key) : key;
  260. const outputs = [ new pytorch.Parameter('output', true, [ new pytorch.Argument(name, null, null) ]) ];
  261. const attributes = [];
  262. for (const name of Object.keys(obj)) {
  263. if (!name.startsWith('_')) {
  264. attributes.push({ name: name, value: obj[name] });
  265. }
  266. }
  267. const item = {
  268. name: name,
  269. type: type,
  270. attributes: attributes,
  271. inputs: inputs,
  272. outputs: outputs
  273. };
  274. const node = new pytorch.Node(metadata, group, item, {});
  275. this._nodes.push(node);
  276. return node;
  277. }
  278. _loadScriptModule(metadata, container, module, initializers) {
  279. if (module) {
  280. if (pytorch.Graph._getParameters(module).length > 0 && !module.__hide__) {
  281. const item = { module: module };
  282. this._nodes.push(new pytorch.Node(metadata, '', item, initializers));
  283. }
  284. const submodules = pytorch.Graph._getSubmodules(module);
  285. for (const submodule of submodules) {
  286. this._loadScriptModule(metadata, container, submodule, initializers);
  287. }
  288. }
  289. }
  290. static _getParameters(module) {
  291. const parameters = [];
  292. if (module && module.__module__ && module.__name__) {
  293. for (const key of Object.keys(module)) {
  294. if (pytorch.Utility.isTensor(module[key])) {
  295. const parameter = module[key];
  296. parameter.__id__ = key;
  297. parameters.push(parameter);
  298. }
  299. }
  300. }
  301. return parameters;
  302. }
  303. static _getSubmodules(module) {
  304. const submodules = [];
  305. if (module && module.__module__ && module.__name__) {
  306. for (const key of Object.keys(module)) {
  307. if (!key.startsWith('__')) {
  308. const value = module[key];
  309. if (value && value.__module__ && value.__name__ && !pytorch.Utility.isTensor(value)) {
  310. submodules.push(value);
  311. }
  312. }
  313. }
  314. }
  315. return submodules;
  316. }
  317. get type() {
  318. return this._type;
  319. }
  320. get name() {
  321. return this._name;
  322. }
  323. get groups() {
  324. return this._groups;
  325. }
  326. get inputs() {
  327. return this._inputs;
  328. }
  329. get outputs() {
  330. return this._outputs;
  331. }
  332. get nodes() {
  333. return this._nodes;
  334. }
  335. };
  336. pytorch.Parameter = class {
  337. constructor(name, visible, args) {
  338. this._name = name;
  339. this._visible = visible;
  340. this._arguments = args;
  341. }
  342. get name() {
  343. return this._name;
  344. }
  345. get visible() {
  346. return this._visible;
  347. }
  348. get arguments() {
  349. return this._arguments;
  350. }
  351. };
  352. pytorch.Argument = class {
  353. constructor(name, type, initializer) {
  354. if (typeof name !== 'string') {
  355. throw new pytorch.Error("Invalid argument identifier '" + JSON.stringify(name) + "'.");
  356. }
  357. this._name = name;
  358. this._type = type;
  359. this._initializer = initializer;
  360. }
  361. get name() {
  362. return this._name;
  363. }
  364. get type() {
  365. if (this._initializer) {
  366. return this._initializer.type;
  367. }
  368. return this._type;
  369. }
  370. get initializer() {
  371. return this._initializer;
  372. }
  373. };
  374. pytorch.Node = class {
  375. constructor(metadata, group, item, initializers) {
  376. this._metadata = metadata;
  377. this._group = group || '';
  378. this._name = item.name || '';
  379. if (!item.module && !item.node) {
  380. this._type = item.type;
  381. this._inputs = item.inputs;
  382. this._outputs = item.outputs;
  383. this._attributes = item.attributes.map((attribute) => {
  384. const schema = metadata.attribute(this._type, attribute.name);
  385. return new pytorch.Attribute(schema, attribute.name, attribute.value);
  386. });
  387. }
  388. else {
  389. this._attributes = [];
  390. this._inputs = [];
  391. this._outputs = [];
  392. let module = item.module;
  393. if (module) {
  394. this._type = 'torch.nn.modules.module.Module';
  395. for (const parameter of pytorch.Graph._getParameters(module)) {
  396. this._inputs.push(new pytorch.Parameter(parameter.__id__, true, [
  397. new pytorch.Argument('', null, parameter.initializer || null)
  398. ]));
  399. if (parameter.__variable__) {
  400. this._outputs.push(new pytorch.Parameter(parameter.__id__, true, [
  401. new pytorch.Argument(parameter.__variable__, null, null)
  402. ]));
  403. }
  404. }
  405. }
  406. if (item.node) {
  407. this._type = item.type;
  408. const schema = metadata.type(this._type);
  409. module = null;
  410. let match = true;
  411. let count = 0;
  412. for (const input of item.node.inputs) {
  413. for (const argument of input) {
  414. const parameter = initializers.get(argument.id);
  415. if (parameter) {
  416. if (parameter.__parent__ && (module == null || module == parameter.__parent__)) {
  417. module = parameter.__parent__;
  418. count++;
  419. }
  420. else {
  421. match = false;
  422. break;
  423. }
  424. }
  425. }
  426. if (!match) {
  427. break;
  428. }
  429. }
  430. if (module) {
  431. const params = pytorch.Graph._getParameters(module).filter((p) => p.__id__ !== 'num_batches_tracked');
  432. if (params.length == count && match) {
  433. module.__hide__ = true;
  434. for (const input of item.node.inputs) {
  435. for (const argument of input) {
  436. const parameter = initializers.get(argument.id);
  437. if (parameter && parameter.initializer) {
  438. argument.initializer = parameter.initializer;
  439. }
  440. }
  441. }
  442. }
  443. else {
  444. module = null;
  445. }
  446. }
  447. for (let inputIndex = 0; inputIndex < item.node.inputs.length; inputIndex++) {
  448. let inputName = inputIndex.toString();
  449. if (schema && schema.inputs && schema.inputs.length > inputIndex) {
  450. inputName = schema.inputs[inputIndex].name;
  451. }
  452. this._inputs.push(new pytorch.Parameter(inputName, true,
  453. item.node.inputs[inputIndex].map((input) => new pytorch.Argument(input.id, null, input.initializer || null))
  454. ));
  455. }
  456. for (let outputIndex = 0; outputIndex < item.node.outputs.length; outputIndex++) {
  457. let outputName = outputIndex.toString();
  458. if (schema && schema.outputs && schema.outputs.length > outputIndex) {
  459. outputName = schema.outputs[outputIndex].name;
  460. }
  461. this._outputs.push(new pytorch.Parameter(outputName, true,
  462. item.node.outputs[outputIndex].map((output) => new pytorch.Argument(output.id, null, null))
  463. ));
  464. }
  465. for (const attribute of item.node.attributes) {
  466. const name = attribute.name;
  467. const value = attribute.value;
  468. const schema = metadata.attribute(this._type, name);
  469. this._attributes.push(new pytorch.Attribute(schema, name, value));
  470. }
  471. }
  472. if (module) {
  473. if (module.__id__) {
  474. let current = module;
  475. this._name = current.__id__;
  476. while (current.__parent__ != null) {
  477. current = current.__parent__;
  478. if (!current.__parent__ && !current.__id__) {
  479. break;
  480. }
  481. this._name = [ current.__id__, this._name ].join('.');
  482. }
  483. }
  484. }
  485. }
  486. }
  487. get name() {
  488. return this._name;
  489. }
  490. get group() {
  491. return this._group;
  492. }
  493. get type() {
  494. const index = this._type.indexOf(':');
  495. return index === -1 ? this._type : this._type.substring(0, index);
  496. }
  497. get metadata() {
  498. return this._metadata.type(this._type);
  499. }
  500. get function() {
  501. return this._type.startsWith('torch.nn.modules.') && this._type !== 'torch.nn.modules.module.Module';
  502. }
  503. get attributes() {
  504. return this._attributes;
  505. }
  506. get inputs() {
  507. return this._inputs;
  508. }
  509. get outputs() {
  510. return this._outputs;
  511. }
  512. };
  pytorch.Attribute = class {

      /**
       * A named attribute of a PyTorch node.
       *
       * Processing steps:
       * - The raw value is unwrapped from `{ type, value }` literal nodes.
       * - It is converted according to `schema.type`.
       * - It is hidden when the schema marks it invisible or it equals `schema.default`.
       *
       * @param {object} [schema] - attribute metadata; `type`, `default` and `visible` are read when present.
       * @param {string} name - attribute name.
       * @param {*} value - raw attribute value (literal, array, or `{ type, value }` node).
       */
      constructor(schema, name, value) {
          this._name = name;
          this._value = value;
          // 'training' is always reported as a hidden boolean.
          if (this._name === 'training') {
              this._visible = false;
              this._type = 'boolean';
              return;
          }
          // Unwrap scalar literal nodes to their payload.
          if (value && [ 'number', 'string', 'boolean', 'id' ].includes(value.type)) {
              this._value = value.value;
          }
          if (schema) {
              if (Object.prototype.hasOwnProperty.call(schema, 'type')) {
                  this._type = schema.type;
              }
              this._value = pytorch.Attribute._convert(this._type, this._value);
              if (Object.prototype.hasOwnProperty.call(schema, 'visible') && !schema.visible) {
                  this._visible = false;
              }
              else if (Object.prototype.hasOwnProperty.call(schema, 'default') &&
                  pytorch.Attribute._isDefault(this._value, schema.default)) {
                  this._visible = false;
              }
          }
          // A non-empty list made only of torch.nn objects is shown as a placeholder.
          if (Array.isArray(value) && value.length > 0 &&
              value.every((obj) => obj && typeof obj.__module__ === 'string' && obj.__module__.startsWith('torch.nn'))) {
              this._value = '?';
          }
      }

      /**
       * Converts a raw value to the representation implied by a schema type.
       * Values that cannot be converted are returned unchanged.
       * @param {string} [type] - schema type name such as 'boolean', 'int64' or 'int64[]'.
       * @param {*} value - value to convert.
       * @returns {*} converted value.
       */
      static _convert(type, value) {
          switch (type) {
              case 'boolean':
                  if (value === 'False') {
                      return false;
                  }
                  if (value === 'True') {
                      return true;
                  }
                  return value;
              case 'int32':
              case 'int64': {
                  if (typeof value === 'string') {
                      const number = parseInt(value, 10);
                      // Keep non-numeric text (e.g. 'None') instead of replacing it with NaN.
                      return Number.isNaN(number) ? value : number;
                  }
                  return value;
              }
              case 'float32':
              case 'float64': {
                  if (typeof value === 'string') {
                      const number = parseFloat(value);
                      return Number.isNaN(number) ? value : number;
                  }
                  return value;
              }
              case 'int32[]':
              case 'int64[]':
                  // Only list nodes are unpacked; null/undefined and plain arrays pass through.
                  if (value && value.type === 'list' && Array.isArray(value.value)) {
                      return value.value.map((item) => {
                          if (item && item.type === 'number') {
                              const number = parseInt(item.value, 10);
                              if (!Number.isNaN(item.value - number)) {
                                  return number;
                              }
                          }
                          return item;
                      });
                  }
                  return value;
              default:
                  return value;
          }
      }

      /**
       * Tests whether a value matches a schema default.
       * A match is either structural (JSON) equality, or an array whose every element equals a scalar default.
       * @param {*} value - converted attribute value.
       * @param {*} defaultValue - `schema.default`.
       * @returns {boolean}
       */
      static _isDefault(value, defaultValue) {
          if (JSON.stringify(defaultValue) === JSON.stringify(value)) {
              return true;
          }
          // Loose (==) comparison: an element such as '1' also matches a default of 1.
          return Array.isArray(value) && !Array.isArray(defaultValue) &&
              value.every((item) => item == defaultValue);
      }

      get type() {
          return this._type;
      }

      get name() {
          return this._name;
      }

      get value() {
          return this._value;
      }

      /** @returns {boolean} false only when the attribute was explicitly hidden. */
      get visible() {
          return this._visible !== false;
      }
  };
  pytorch.Tensor = class {

      /**
       * A tensor backed by a typed-array storage.
       * @param {string} name - tensor name ('' when falsy).
       * @param {object} tensor - object exposing `storage.dataType`, `storage.data` (a typed array) and `size`.
       * @param {boolean} littleEndian - byte order passed to the DataView reads.
       */
      constructor(name, tensor, littleEndian) {
          this._name = name || '';
          this._type = new pytorch.TensorType(tensor.storage.dataType, new pytorch.TensorShape(tensor.size));
          this._data = tensor.storage.data;
          this._littleEndian = littleEndian;
      }

      get kind() {
          return 'Tensor';
      }

      get name() {
          return this._name;
      }

      get type() {
          return this._type;
      }

      /** @returns {string|null} reason the data cannot be decoded, or null when it can. */
      get state() {
          return this._context().state;
      }

      /** @returns {*} fully decoded (nested) values, or null when `state` is set. */
      get value() {
          const context = this._context();
          if (context.state) {
              return null;
          }
          context.limit = Number.MAX_SAFE_INTEGER;
          return this._decode(context, 0);
      }

      /** @returns {string} indented text of at most ~10000 elements, or '' when `state` is set. */
      toString() {
          const context = this._context();
          if (context.state) {
              return '';
          }
          context.limit = 10000;
          const value = this._decode(context, 0);
          return pytorch.Tensor._stringify(value, '', ' ');
      }

      /**
       * Builds the decoding context, or a context whose `state` explains why decoding is impossible.
       * @returns {object}
       */
      _context() {
          const context = {};
          context.state = null;
          context.index = 0;
          context.count = 0;
          const dataType = this._type.dataType;
          if (!dataType) {
              context.state = 'Tensor has no data type.';
              return context;
          }
          const reader = pytorch.Tensor._reader(dataType);
          if (!reader) {
              context.state = "Tensor data type '" + dataType + "' is not supported.";
              return context;
          }
          if (!this._type.shape) {
              context.state = 'Tensor has no dimensions.';
              return context;
          }
          if (!this._data) {
              context.state = 'Tensor data is empty.';
              return context;
          }
          const dimensions = this._type.shape.dimensions;
          // Reject storage too small for the shape; DataView would otherwise throw a RangeError mid-decode.
          // A non-numeric product (NaN) makes the comparison false and skips the check.
          const size = dimensions.reduce((a, b) => a * b, 1);
          if (size * reader.itemsize > this._data.byteLength) {
              context.state = 'Tensor data is too small for its shape.';
              return context;
          }
          context.data = this._data;
          context.dataType = dataType;
          context.dimensions = dimensions;
          context.reader = reader;
          context.dataView = new DataView(context.data.buffer, context.data.byteOffset, context.data.byteLength);
          return context;
      }

      /**
       * Recursively decodes the dimension `dimension` and everything below it.
       * Stops with a '...' marker once `context.limit` elements have been read.
       * A rank-0 tensor decodes to a single value.
       */
      _decode(context, dimension) {
          const results = [];
          const dimensions = (context.dimensions.length === 0) ? [ 1 ] : context.dimensions;
          const size = dimensions[dimension];
          if (dimension === dimensions.length - 1) {
              const method = context.reader.method;
              const itemsize = context.reader.itemsize;
              for (let i = 0; i < size; i++) {
                  if (context.count > context.limit) {
                      results.push('...');
                      return results;
                  }
                  results.push(context.dataView[method](context.index, this._littleEndian));
                  context.index += itemsize;
                  context.count++;
              }
          }
          else {
              for (let j = 0; j < size; j++) {
                  if (context.count > context.limit) {
                      results.push('...');
                      return results;
                  }
                  results.push(this._decode(context, dimension + 1));
              }
          }
          if (context.dimensions.length === 0) {
              return results[0];
          }
          return results;
      }

      /**
       * Maps a data type name to the DataView getter name and element size in bytes.
       * @param {string} dataType
       * @returns {{method: string, itemsize: number}|null} null for unsupported types.
       */
      static _reader(dataType) {
          switch (dataType) {
              case 'uint8': return { method: 'getUint8', itemsize: 1 };
              case 'qint8':
              case 'int8': return { method: 'getInt8', itemsize: 1 };
              case 'int16': return { method: 'getInt16', itemsize: 2 };
              case 'int32': return { method: 'getInt32', itemsize: 4 };
              // NOTE(review): getInt64 is not a standard DataView method; presumably installed by the project's base module — confirm.
              case 'int64': return { method: 'getInt64', itemsize: 8 };
              case 'float16': return { method: 'getFloat16', itemsize: 2 };
              case 'float32': return { method: 'getFloat32', itemsize: 4 };
              case 'float64': return { method: 'getFloat64', itemsize: 8 };
              default: return null;
          }
      }

      /** Formats decoded values as nested bracketed lines, one element per line. */
      static _stringify(value, indentation, indent) {
          if (Array.isArray(value)) {
              const result = [];
              result.push(indentation + '[');
              const items = value.map((item) => pytorch.Tensor._stringify(item, indentation + indent, indent));
              if (items.length > 0) {
                  result.push(items.join(',\n'));
              }
              result.push(indentation + ']');
              return result.join('\n');
          }
          if (value && (value instanceof base.Int64 || value instanceof base.Uint64)) {
              return indentation + value.toString();
          }
          if (typeof value == 'string') {
              return indentation + value;
          }
          if (value == Infinity) {
              return indentation + 'Infinity';
          }
          if (value == -Infinity) {
              return indentation + '-Infinity';
          }
          if (isNaN(value)) {
              return indentation + 'NaN';
          }
          return indentation + value.toString();
      }
  };
  pytorch.TensorType = class {

      /**
       * Pairs an element data type name with a tensor shape.
       * @param {string} dataType - element type name, e.g. 'float32'.
       * @param {object} shape - shape object whose `toString()` renders the dimensions.
       */
      constructor(dataType, shape) {
          this._dataType = dataType;
          this._shape = shape;
      }

      get dataType() { return this._dataType; }

      get shape() { return this._shape; }

      /** @returns {string} the data type immediately followed by the shape text, e.g. 'float32[1,3]'. */
      toString() {
          const typeName = this._dataType;
          const shapeText = this._shape.toString();
          return typeName + shapeText;
      }
  };
  pytorch.TensorShape = class {

      /**
       * Ordered list of tensor dimensions.
       * @param {Array} [dimensions] - dimension sizes; a falsy argument becomes [].
       */
      constructor(dimensions) {
          this._dimensions = dimensions || [];
      }

      get dimensions() { return this._dimensions; }

      /** @returns {string} '[d0,d1,...]', or '' for a rank-0 shape. */
      toString() {
          const dims = this._dimensions;
          const hasDimensions = Boolean(dims) && dims.length > 0;
          if (!hasDimensions) {
              return '';
          }
          const parts = dims.map((size) => size.toString());
          return '[' + parts.join(',') + ']';
      }
  };
  pytorch.Metadata = class {

      /**
       * Returns the shared metadata instance, loading 'pytorch-metadata.json' through `host` on first use.
       * Any failure while requesting or parsing the file falls back to empty metadata.
       * @param {object} host - object providing `request(base, file, encoding)` returning a Promise.
       * @returns {Promise<pytorch.Metadata>}
       */
      static open(host) {
          if (pytorch.Metadata._metadata) {
              return Promise.resolve(pytorch.Metadata._metadata);
          }
          return host.request(null, 'pytorch-metadata.json', 'utf-8').then((data) => {
              pytorch.Metadata._metadata = new pytorch.Metadata(data);
              return pytorch.Metadata._metadata;
          }).catch(() => {
              // Best effort: missing or malformed metadata yields an empty instance.
              pytorch.Metadata._metadata = new pytorch.Metadata(null);
              return pytorch.Metadata._metadata;
          });
      }

      /**
       * @param {string|null} data - JSON text: an array of `{ name, schema }` items.
       *
       * Each schema is stored under its full name. An entry named 'base:suffix' is additionally
       * listed under 'base', unless 'base' already holds a plain schema.
       */
      constructor(data) {
          this._map = new Map();
          this._attributeCache = new Map();
          if (!data) {
              return;
          }
          const items = JSON.parse(data);
          if (!Array.isArray(items)) {
              return;
          }
          for (const item of items) {
              // Entries without a name or schema cannot be looked up; skip them.
              if (!item || !item.name || !item.schema) {
                  continue;
              }
              item.schema.name = item.name;
              this._map.set(item.name, item.schema);
              const index = item.name.indexOf(':');
              if (index !== -1) {
                  const name = item.name.substring(0, index);
                  if (!this._map.has(name)) {
                      this._map.set(name, []);
                  }
                  const overloads = this._map.get(name);
                  // A plain schema registered under `name` is not an overload list; leave it alone.
                  if (Array.isArray(overloads)) {
                      overloads.push(item.name);
                  }
              }
          }
      }

      /**
       * @param {string} name - operator/type name.
       * @returns {object|object[]|null} the schema, an array of overload schemas, or null when unknown.
       */
      type(name) {
          const schema = this._map.get(name);
          if (!schema) {
              return null;
          }
          return Array.isArray(schema) ? schema.map((key) => this._map.get(key)) : schema;
      }

      /**
       * Looks up the input or attribute descriptor `name` of `type`.
       * The first call for a type caches all of its inputs and attributes. When both define the same name,
       * the attribute wins. Overload arrays are not searched.
       * @returns {object|null}
       */
      attribute(type, name) {
          const attributeName = type + ':' + name;
          if (!this._attributeCache.has(attributeName)) {
              this._attributeCache.set(attributeName, null);
              const schema = this.type(type);
              if (schema) {
                  if (schema.inputs) {
                      for (const input of schema.inputs) {
                          this._attributeCache.set(type + ':' + input.name, input);
                      }
                  }
                  if (schema.attributes) {
                      for (const attribute of schema.attributes) {
                          this._attributeCache.set(type + ':' + attribute.name, attribute);
                      }
                  }
              }
          }
          return this._attributeCache.get(attributeName);
      }
  };
  pytorch.Error = class extends Error {

      /**
       * Error raised while loading a PyTorch model; `name` carries a fixed title.
       * @param {string} message - description of the failure.
       */
      constructor(message) {
          const title = 'Error loading PyTorch model.';
          super(message);
          this.name = title;
      }
  };
  889. pytorch.Execution = class {
  890. constructor(python, sources, exceptionCallback) {
  891. const self = this;
  892. this._python = python;
  893. this._sources = sources;
  894. this._exceptionCallback = exceptionCallback;
  895. this._utf8Decoder = new TextDecoder('utf-8');
  896. this._unknownNameMap = new Set();
  897. this._knownPackageMap = new Set([ 'torch', 'torchvision', 'collections', '__builtin__', '_codecs', 'argparse', 'numpy' ]);
  898. this._packages = new Map();
  899. this._context = new pytorch.Execution.Context();
  900. this._context.scope.builtins = {};
  901. this._context.scope.builtins.type = { __module__: 'builtins', __name__: 'type' };
  902. this._context.scope.builtins.module = { __module__: 'builtins', __name__: 'module', __class__: this._context.scope.builtins.type };
  903. this._context.scope.builtins.function = { __module__: 'builtins', __name__: 'function', __class__: this._context.scope.builtins.type };
  904. this._context.scope.builtins.method = { __module__: 'builtins', __name__: 'method', __class__: this._context.scope.builtins.type };
  905. this._context.scope.builtins.dict = { __module__: 'builtins', __name__: 'dict', __class__: this._context.scope.builtins.type };
  906. this._context.scope.builtins.list = { __module__: 'builtins', __name__: 'list', __class__: this._context.scope.builtins.type };
  907. this._context.scope.builtins.str = { __module__: 'builtins', __name__: 'str', __class__: this._context.scope.builtins.type };
  908. this._context.scope.builtins.tuple = { __module__: 'builtins', __name__: 'tuple', __class__: this._context.scope.builtins.type };
  909. this._context.scope.typing = { __name__: 'typing', __class__: this._context.scope.builtins.module };
  910. this._context.scope.typing._GenericAlias = { __module__: 'typing', __name__: '_GenericAlias', __class__: this._context.scope.builtins.type };
  911. this._context.scope.typing._SpecialForm = { __module__: 'typing', __name__: '_SpecialForm', __class__: this._context.scope.builtins.type };
  912. this._context.scope.typing._VariadicGenericAlias = { __module__: 'typing', __name__: '_VariadicGenericAlias', __class__: this._context.scope.builtins.type };
  913. this._context.scope.typing.Dict = { __module__: 'typing', __name__: 'Dict', __class__: this._context.scope.typing._VariadicGenericAlias, __origin__: this._context.scope.builtins.dict };
  914. this._context.scope.typing.List = { __module__: 'typing', __name__: 'List', __class__: this._context.scope.typing._GenericAlias, __origin__: this._context.scope.builtins.list };
  915. this._context.scope.typing.Optional = { __module__: 'typing', __class__: this._context.scope.typing._SpecialForm };
  916. this._context.scope.typing.Tuple = { __module__: 'typing', __name__: 'Tuple', __class__: this._context.scope.typing._GenericAlias, __origin__: this._context.scope.builtins.tuple };
  917. this._context.scope.torch = { __name__: 'torch', __class__: this._context.scope.builtins.module };
  918. this._context.scope.torch.Tensor = { __module__: 'torch', __name__: 'Tensor', __class__: this._context.scope.builtins.type };
  919. this._registerConstructor('argparse.Namespace', function (args) {
  920. this.args = args;
  921. });
  922. this._registerConstructor('torch.autograd.variable.Variable', function() {});
  923. this._registerConstructor('torch.backends.cudnn.rnn.Unserializable', function() {});
  924. this._registerConstructor('torch.device', function(type, index) {
  925. this.type = type;
  926. if (index) {
  927. this.index = index;
  928. }
  929. });
  930. this._registerConstructor('torch.distributions.multivariate_normal.MultivariateNormal', function() {});
  931. this._registerConstructor('torch.distributions.transforms.LowerCholeskyTransform', function() {});
  932. this._registerConstructor('torch.nn.backends.thnn._get_thnn_function_backend', function() {});
  933. this._registerConstructor('torch.nn.intrinsic.modules.fused.ConvReLU2d', function() {});
  934. this._registerConstructor('torch.nn.modules.activation.CELU', function() {});
  935. this._registerConstructor('torch.nn.modules.activation.ELU', function() {});
  936. this._registerConstructor('torch.nn.modules.activation.GELU', function() {});
  937. this._registerConstructor('torch.nn.modules.activation.GLU', function() {});
  938. this._registerConstructor('torch.nn.modules.activation.Hardtanh', function() {});
  939. this._registerConstructor('torch.nn.modules.activation.LeakyReLU', function() {});
  940. this._registerConstructor('torch.nn.modules.activation.LogSigmoid', function() {});
  941. this._registerConstructor('torch.nn.modules.activation.LogSoftmax', function() {});
  942. this._registerConstructor('torch.nn.modules.activation.MultiheadAttention', function() {});
  943. this._registerConstructor('torch.nn.modules.activation.ReLU', function() {});
  944. this._registerConstructor('torch.nn.modules.activation.ReLU6', function() {});
  945. this._registerConstructor('torch.nn.modules.activation.PReLU', function() {});
  946. this._registerConstructor('torch.nn.modules.activation.RReLU', function() {});
  947. this._registerConstructor('torch.nn.modules.activation.SELU', function() {});
  948. this._registerConstructor('torch.nn.modules.activation.Sigmoid', function() {});
  949. this._registerConstructor('torch.nn.modules.activation.Softmax', function() {});
  950. this._registerConstructor('torch.nn.modules.activation.Softmax2d', function() {});
  951. this._registerConstructor('torch.nn.modules.activation.Softplus', function() {});
  952. this._registerConstructor('torch.nn.modules.activation.Tanh', function() {});
  953. this._registerConstructor('torch.nn.modules.activation.Threshold', function() {});
  954. this._registerConstructor('torch.nn.modules.batchnorm.BatchNorm1d', function() {});
  955. this._registerConstructor('torch.nn.modules.batchnorm.BatchNorm2d', function() {});
  956. this._registerConstructor('torch.nn.modules.batchnorm.BatchNorm3d', function() {});
  957. this._registerConstructor('torch.nn.modules.batchnorm.SyncBatchNorm', function() {});
  958. this._registerConstructor('torch.nn.modules.container.ModuleDict', function() {});
  959. this._registerConstructor('torch.nn.modules.container.ModuleList', function() {});
  960. this._registerConstructor('torch.nn.modules.container.ParameterList', function() {});
  961. this._registerConstructor('torch.nn.modules.container.Sequential', function() {});
  962. this._registerConstructor('torch.nn.modules.conv.Conv1d', function() {});
  963. this._registerConstructor('torch.nn.modules.conv.Conv2d', function() {});
  964. this._registerConstructor('torch.nn.modules.conv.Conv3d', function() {});
  965. this._registerConstructor('torch.nn.modules.conv.ConvTranspose1d', function() {});
  966. this._registerConstructor('torch.nn.modules.conv.ConvTranspose2d', function() {});
  967. this._registerConstructor('torch.nn.modules.conv.ConvTranspose3d', function() {});
  968. this._registerConstructor('torch.nn.modules.distance.CosineSimilarity', function() {});
  969. this._registerConstructor('torch.nn.modules.dropout.Dropout', function() {});
  970. this._registerConstructor('torch.nn.modules.dropout.Dropout2d', function() {});
  971. this._registerConstructor('torch.nn.modules.dropout.Dropout3d', function() {});
  972. this._registerConstructor('torch.nn.modules.fold.Unfold', function() {});
  973. this._registerConstructor('torch.nn.modules.flatten.Flatten', function() {});
  974. this._registerConstructor('torch.nn.modules.instancenorm.InstanceNorm1d', function() {});
  975. this._registerConstructor('torch.nn.modules.instancenorm.InstanceNorm2d', function() {});
  976. this._registerConstructor('torch.nn.modules.instancenorm.InstanceNorm3d', function() {});
  977. this._registerConstructor('torch.nn.modules.linear._LinearWithBias', function() {});
  978. this._registerConstructor('torch.nn.modules.linear.Linear', function() {});
  979. this._registerConstructor('torch.nn.modules.linear.Identity', function() {});
  980. this._registerConstructor('torch.nn.modules.loss.BCELoss', function() {});
  981. this._registerConstructor('torch.nn.modules.loss.BCEWithLogitsLoss', function() {});
  982. this._registerConstructor('torch.nn.modules.loss.CrossEntropyLoss', function() {});
  983. this._registerConstructor('torch.nn.modules.loss.L1Loss', function() {});
  984. this._registerConstructor('torch.nn.modules.loss.MSELoss', function() {});
  985. this._registerConstructor('torch.nn.modules.loss.NLLLoss', function() {});
  986. this._registerConstructor('torch.nn.modules.loss.SmoothL1Loss', function() {});
  987. this._registerConstructor('torch.nn.modules.normalization.CrossMapLRN2d', function() {});
  988. this._registerConstructor('torch.nn.modules.normalization.GroupNorm', function() {});
  989. this._registerConstructor('torch.nn.modules.normalization.LayerNorm', function() {});
  990. this._registerConstructor('torch.nn.modules.normalization.LocalResponseNorm', function() {});
  991. this._registerConstructor('torch.nn.modules.padding.ReflectionPad1d', function() {});
  992. this._registerConstructor('torch.nn.modules.padding.ReflectionPad2d', function() {});
  993. this._registerConstructor('torch.nn.modules.padding.ReplicationPad1d', function() {});
  994. this._registerConstructor('torch.nn.modules.padding.ReplicationPad2d', function() {});
  995. this._registerConstructor('torch.nn.modules.padding.ReplicationPad3d', function() {});
  996. this._registerConstructor('torch.nn.modules.padding.ZeroPad2d', function() {});
  997. this._registerConstructor('torch.nn.modules.padding.ConstantPad1d', function() {});
  998. this._registerConstructor('torch.nn.modules.padding.ConstantPad2d', function() {});
  999. this._registerConstructor('torch.nn.modules.padding.ConstantPad3d', function() {});
  1000. this._registerConstructor('torch.nn.modules.pixelshuffle.PixelShuffle', function() {});
  1001. this._registerConstructor('torch.nn.modules.pooling.AdaptiveAvgPool1d', function() {});
  1002. this._registerConstructor('torch.nn.modules.pooling.AdaptiveAvgPool2d', function() {});
  1003. this._registerConstructor('torch.nn.modules.pooling.AdaptiveAvgPool3d', function() {});
  1004. this._registerConstructor('torch.nn.modules.pooling.AdaptiveMaxPool1d', function() {});
  1005. this._registerConstructor('torch.nn.modules.pooling.AdaptiveMaxPool2d', function() {});
  1006. this._registerConstructor('torch.nn.modules.pooling.AdaptiveMaxPool3d', function() {});
  1007. this._registerConstructor('torch.nn.modules.pooling.AvgPool1d', function() {});
  1008. this._registerConstructor('torch.nn.modules.pooling.AvgPool2d', function() {});
  1009. this._registerConstructor('torch.nn.modules.pooling.AvgPool3d', function() {});
  1010. this._registerConstructor('torch.nn.modules.pooling.FractionalMaxPool2d', function() {});
  1011. this._registerConstructor('torch.nn.modules.pooling.MaxPool1d', function() {});
  1012. this._registerConstructor('torch.nn.modules.pooling.MaxPool2d', function() {});
  1013. this._registerConstructor('torch.nn.modules.pooling.MaxPool3d', function() {});
  1014. this._registerConstructor('torch.nn.modules.pooling.MaxUnpool1d', function() {});
  1015. this._registerConstructor('torch.nn.modules.pooling.MaxUnpool2d', function() {});
  1016. this._registerConstructor('torch.nn.modules.pooling.MaxUnpool3d', function() {});
  1017. this._registerConstructor('torch.nn.modules.rnn.GRU', function() {});
  1018. this._registerConstructor('torch.nn.modules.rnn.GRUCell', function() {});
  1019. this._registerConstructor('torch.nn.modules.rnn.LSTM', function() {});
  1020. this._registerConstructor('torch.nn.modules.rnn.LSTMCell', function() {});
  1021. this._registerConstructor('torch.nn.modules.rnn.RNN', function() {});
  1022. this._registerConstructor('torch.nn.modules.sparse.Embedding', function() {});
  1023. this._registerConstructor('torch.nn.modules.sparse.EmbeddingBag', function() {});
  1024. this._registerConstructor('torch.nn.modules.transformer.Transformer', function() {});
  1025. this._registerConstructor('torch.nn.modules.transformer.TransformerDecoder', function() {});
  1026. this._registerConstructor('torch.nn.modules.transformer.TransformerDecoderLayer', function() {});
  1027. this._registerConstructor('torch.nn.modules.transformer.TransformerEncoder', function() {});
  1028. this._registerConstructor('torch.nn.modules.transformer.TransformerEncoderLayer', function() {});
  1029. this._registerConstructor('torch.nn.modules.upsampling.Upsample', function() {});
  1030. this._registerConstructor('torch.nn.modules.upsampling.UpsamplingBilinear2d', function() {});
  1031. this._registerConstructor('torch.nn.modules.upsampling.UpsamplingNearest2d', function() {});
  1032. this._registerConstructor('torch.nn.parallel.data_parallel.DataParallel', function() {});
  1033. this._registerConstructor('torch.nn.parallel.distributed.DistributedDataParallel', function() {});
  1034. this._registerConstructor('torch.nn.parameter.Parameter', function(data, requires_grad) {
  1035. this.data = data;
  1036. this.requires_grad = requires_grad;
  1037. });
  1038. this._registerConstructor('torch.nn.quantized.modules.functional_modules.FloatFunctional', function() {});
  1039. this._registerConstructor('torch.nn.utils.spectral_norm.SpectralNorm', function() {});
  1040. this._registerConstructor('torch.nn.utils.spectral_norm.SpectralNormStateDictHook', function() {});
  1041. this._registerConstructor('torch.nn.utils.spectral_norm.SpectralNormLoadStateDictPreHook', function() {});
  1042. this._registerConstructor('torch.nn.utils.weight_norm.WeightNorm', function() {});
  1043. this._registerConstructor('torch.optim.adam.Adam', function() {});
  1044. this._registerConstructor('torch.optim.adagrad.Adagrad', function() {});
  1045. this._registerConstructor('torch.optim.lr_scheduler.MultiStepLR', function() {});
  1046. this._registerConstructor('torch.optim.lr_scheduler.StepLR', function() {});
  1047. this._registerConstructor('torch.optim.rmsprop.RMSprop', function() {});
  1048. this._registerConstructor('torch.optim.sgd.SGD', function() {});
  1049. this._registerConstructor('torch.quantization.stubs.DeQuantStub', function() {});
  1050. this._registerConstructor('torch.quantization.stubs.QuantStub', function() {});
  1051. this._registerConstructor('torchvision.datasets.folder.ImageFolder', function() {});
  1052. this._registerConstructor('torchvision.models.alexnet.AlexNet', function() {});
  1053. this._registerConstructor('torchvision.models.densenet.DenseNet', function() {});
  1054. this._registerConstructor('torchvision.models.densenet._DenseBlock', function() {});
  1055. this._registerConstructor('torchvision.models.densenet._DenseLayer', function() {});
  1056. this._registerConstructor('torchvision.models.densenet._Transition', function() {});
  1057. this._registerConstructor('torchvision.models.detection._utils.BalancedPositiveNegativeSampler', function() {});
  1058. this._registerConstructor('torchvision.models.detection._utils.BoxCoder', function() {});
  1059. this._registerConstructor('torchvision.models.detection._utils.Matcher', function() {});
  1060. this._registerConstructor('torchvision.models.detection.backbone_utils.BackboneWithFPN', function() {});
  1061. this._registerConstructor('torchvision.models.detection.faster_rcnn.FasterRCNN', function() {});
  1062. this._registerConstructor('torchvision.models.detection.faster_rcnn.FastRCNNPredictor', function() {});
  1063. this._registerConstructor('torchvision.models.detection.faster_rcnn.TwoMLPHead', function() {});
  1064. this._registerConstructor('torchvision.models.detection.keypoint_rcnn.KeypointRCNN', function() {});
  1065. this._registerConstructor('torchvision.models.detection.keypoint_rcnn.KeypointRCNNHeads', function() {});
  1066. this._registerConstructor('torchvision.models.detection.keypoint_rcnn.KeypointRCNNPredictor', function() {});
  1067. this._registerConstructor('torchvision.models.detection.mask_rcnn.MaskRCNN', function() {});
  1068. this._registerConstructor('torchvision.models.detection.mask_rcnn.MaskRCNNHeads', function() {});
  1069. this._registerConstructor('torchvision.models.detection.mask_rcnn.MaskRCNNPredictor', function() {});
  1070. this._registerConstructor('torchvision.models.detection.roi_heads.RoIHeads', function() {});
  1071. this._registerConstructor('torchvision.models.detection.rpn.AnchorGenerator', function() {});
  1072. this._registerConstructor('torchvision.models.detection.rpn.RegionProposalNetwork', function() {});
  1073. this._registerConstructor('torchvision.models.detection.rpn.RPNHead', function() {});
  1074. this._registerConstructor('torchvision.models.detection.transform.GeneralizedRCNNTransform', function() {});
  1075. this._registerConstructor('torchvision.models.googlenet.BasicConv2d', function() {});
  1076. this._registerConstructor('torchvision.models.googlenet.GoogLeNet', function() {});
  1077. this._registerConstructor('torchvision.models.googlenet.Inception', function() {});
  1078. this._registerConstructor('torchvision.models.inception.BasicConv2d', function() {});
  1079. this._registerConstructor('torchvision.models.inception.Inception3', function() {});
  1080. this._registerConstructor('torchvision.models.inception.InceptionAux', function() {});
  1081. this._registerConstructor('torchvision.models.inception.InceptionA', function() {});
  1082. this._registerConstructor('torchvision.models.inception.InceptionB', function() {});
  1083. this._registerConstructor('torchvision.models.inception.InceptionC', function() {});
  1084. this._registerConstructor('torchvision.models.inception.InceptionD', function() {});
  1085. this._registerConstructor('torchvision.models.inception.InceptionE', function() {});
  1086. this._registerConstructor('torchvision.models.mobilenet.ConvBNReLU', function() {});
  1087. this._registerConstructor('torchvision.models.mobilenet.MobileNetV2', function() {});
  1088. this._registerConstructor('torchvision.models.mobilenet.InvertedResidual', function() {});
  1089. this._registerConstructor('torchvision.models.resnet.Bottleneck', function() {});
  1090. this._registerConstructor('torchvision.models.resnet.BasicBlock', function() {});
  1091. this._registerConstructor('torchvision.models.quantization.resnet.QuantizableBottleneck', function() {});
  1092. this._registerConstructor('torchvision.models.quantization.resnet.QuantizableResNet', function() {});
  1093. this._registerConstructor('torchvision.models.segmentation.deeplabv3.ASPP', function() {});
  1094. this._registerConstructor('torchvision.models.segmentation.deeplabv3.ASPPConv', function() {});
  1095. this._registerConstructor('torchvision.models.segmentation.deeplabv3.ASPPPooling', function() {});
  1096. this._registerConstructor('torchvision.models.segmentation.deeplabv3.DeepLabHead', function() {});
  1097. this._registerConstructor('torchvision.models.segmentation.deeplabv3.DeepLabV3', function() {});
  1098. this._registerConstructor('torchvision.models.segmentation.fcn.FCN', function() {});
  1099. this._registerConstructor('torchvision.models.segmentation.fcn.FCNHead', function() {});
  1100. this._registerConstructor('torchvision.models.shufflenetv2.ShuffleNetV2', function() {});
  1101. this._registerConstructor('torchvision.models.shufflenetv2.InvertedResidual', function() {});
  1102. this._registerConstructor('torchvision.models.squeezenet.Fire', function() {});
  1103. this._registerConstructor('torchvision.models.squeezenet.SqueezeNet', function() {});
  1104. this._registerConstructor('torchvision.models.resnet.ResNet', function() {});
  1105. this._registerConstructor('torchvision.models.vgg.VGG', function() {});
  1106. this._registerConstructor('torchvision.models.video.resnet.BasicBlock', function() {});
  1107. this._registerConstructor('torchvision.models.video.resnet.BasicStem', function() {});
  1108. this._registerConstructor('torchvision.models.video.resnet.Conv3DNoTemporal', function() {});
  1109. this._registerConstructor('torchvision.models.video.resnet.Conv3DSimple', function() {});
  1110. this._registerConstructor('torchvision.models.video.resnet.VideoResNet', function() {});
  1111. this._registerConstructor('torchvision.models._utils.IntermediateLayerGetter', function() {});
  1112. this._registerConstructor('torchvision.ops.feature_pyramid_network.FeaturePyramidNetwork', function() {});
  1113. this._registerConstructor('torchvision.ops.feature_pyramid_network.LastLevelMaxPool', function() {});
  1114. this._registerConstructor('torchvision.ops.misc.ConvTranspose2d', function() {});
  1115. this._registerConstructor('torchvision.ops.misc.FrozenBatchNorm2d', function() {});
  1116. this._registerConstructor('torchvision.ops.poolers.LevelMapper', function() {});
  1117. this._registerConstructor('torchvision.ops.poolers.MultiScaleRoIAlign', function() {});
  1118. this._registerConstructor('torchvision.transforms.transforms.Compose', function() {});
  1119. this._registerConstructor('torchvision.transforms.transforms.Normalize', function() {});
  1120. this._registerConstructor('torchvision.transforms.transforms.Resize', function() {});
  1121. this._registerConstructor('torchvision.transforms.transforms.ToTensor', function() {});
  1122. this._registerConstructor('torch.ByteStorage', function (size) {
  1123. this.size = size; this.dataTypeSize = 1; this.dataType = 'uint8';
  1124. });
  1125. this._registerConstructor('torch.CharStorage', function (size) {
  1126. this.size = size; this.dataTypeSize = 1; this.dataType = 'int8';
  1127. });
  1128. this._registerConstructor('torch.ShortStorage', function (size) {
  1129. this.size = size; this.dataTypeSize = 2; this.dataType = 'int16';
  1130. });
  1131. this._registerConstructor('torch.IntStorage', function (size) {
  1132. this.size = size; this.dataTypeSize = 4; this.dataType = 'int32';
  1133. });
  1134. this._registerConstructor('torch.LongStorage', function (size) {
  1135. this.size = size; this.dataTypeSize = 8; this.dataType = 'int64';
  1136. });
  1137. this._registerConstructor('torch.HalfStorage', function (size) {
  1138. this.size = size; this.dataTypeSize = 2; this.dataType = 'float16';
  1139. });
  1140. this._registerConstructor('torch.FloatStorage', function (size) {
  1141. this.size = size; this.dataTypeSize = 4; this.dataType = 'float32';
  1142. });
  1143. this._registerConstructor('torch.DoubleStorage', function (size) {
  1144. this.size = size; this.dataTypeSize = 8; this.dataType = 'float64';
  1145. });
  1146. this._registerConstructor('torch.QInt8Storage', function (size) {
  1147. this.size = size; this.dataTypeSize = 1; this.dataType = 'qint8';
  1148. });
  1149. this._registerConstructor('torch.FloatTensor', function () {
  1150. this.__setstate__ = function(state) {
  1151. this.storage = state[0];
  1152. this.storage_offset = state[1];
  1153. this.size = state[2];
  1154. this.stride = state[3];
  1155. };
  1156. });
  1157. this._registerConstructor('torch.DoubleTensor', function () {
  1158. this.__setstate__ = function(state) {
  1159. this.storage = state[0];
  1160. this.storage_offset = state[1];
  1161. this.size = state[2];
  1162. this.stride = state[3];
  1163. };
  1164. });
  1165. this._registerConstructor('torch.cuda.FloatTensor', function () {
  1166. this.__setstate__ = function(state) {
  1167. this.storage = state[0];
  1168. this.storage_offset = state[1];
  1169. this.size = state[2];
  1170. this.stride = state[3];
  1171. };
  1172. });
  1173. this._registerConstructor('torch.cuda.DoubleTensor', function () {
  1174. this.__setstate__ = function(state) {
  1175. this.storage = state[0];
  1176. this.storage_offset = state[1];
  1177. this.size = state[2];
  1178. this.stride = state[3];
  1179. };
  1180. });
  1181. this._registerConstructor('numpy.dtype', function(obj, align, copy) {
  1182. switch (obj) {
  1183. case 'i1': this.name = 'int8'; this.itemsize = 1; break;
  1184. case 'i2': this.name = 'int16'; this.itemsize = 2; break;
  1185. case 'i4': this.name = 'int32'; this.itemsize = 4; break;
  1186. case 'i8': this.name = 'int64'; this.itemsize = 8; break;
  1187. case 'b1': this.name = 'uint8'; this.itemsize = 1; break;
  1188. case 'u1': this.name = 'uint8'; this.itemsize = 1; break;
  1189. case 'u2': this.name = 'uint16'; this.itemsize = 2; break;
  1190. case 'u4': this.name = 'uint32'; this.itemsize = 4; break;
  1191. case 'u8': this.name = 'uint64'; this.itemsize = 8; break;
  1192. case 'f4': this.name = 'float32'; this.itemsize = 4; break;
  1193. case 'f8': this.name = 'float64'; this.itemsize = 8; break;
  1194. default:
  1195. if (obj.startsWith('V')) {
  1196. this.itemsize = Number(obj.substring(1));
  1197. this.name = 'void' + (this.itemsize * 8).toString();
  1198. }
  1199. else if (obj.startsWith('O')) {
  1200. this.itemsize = Number(obj.substring(1));
  1201. this.name = 'object';
  1202. }
  1203. else if (obj.startsWith('S')) {
  1204. this.itemsize = Number(obj.substring(1));
  1205. this.name = 'string';
  1206. }
  1207. else if (obj.startsWith('U')) {
  1208. this.itemsize = Number(obj.substring(1));
  1209. this.name = 'string';
  1210. }
  1211. else if (obj.startsWith('M')) {
  1212. this.itemsize = Number(obj.substring(1));
  1213. this.name = 'datetime';
  1214. }
  1215. else {
  1216. throw new pytorch.Error("Unknown dtype '" + obj.toString() + "'.");
  1217. }
  1218. break;
  1219. }
  1220. this.align = align;
  1221. this.copy = copy;
  1222. this.__setstate__ = function(state) {
  1223. switch (state.length) {
  1224. case 8:
  1225. this.version = state[0];
  1226. this.byteorder = state[1];
  1227. this.subarray = state[2];
  1228. this.names = state[3];
  1229. this.fields = state[4];
  1230. this.elsize = state[5];
  1231. this.alignment = state[6];
  1232. this.int_dtypeflags = state[7];
  1233. break;
  1234. default:
  1235. throw new pytorch.Error("Unknown numpy.dtype setstate length '" + state.length.toString() + "'.");
  1236. }
  1237. };
  1238. });
  1239. this._registerConstructor('numpy.core.multiarray._reconstruct', function(subtype, shape, dtype) {
  1240. this.subtype = subtype;
  1241. this.shape = shape;
  1242. this.dtype = dtype;
  1243. this.__setstate__ = function(state) {
  1244. this.version = state[0];
  1245. this.shape = state[1];
  1246. this.typecode = state[2];
  1247. this.is_f_order = state[3];
  1248. this.rawdata = state[4];
  1249. };
  1250. this.__read__ = function(unpickler) {
  1251. const array = {};
  1252. const subtype = this.subtype.split('.');
  1253. array.__name__ = subtype.pop();
  1254. array.__module__ = subtype.join('.');
  1255. array.dtype = this.typecode;
  1256. array.shape = this.shape;
  1257. let size = array.dtype.itemsize;
  1258. for (let i = 0; i < array.shape.length; i++) {
  1259. size = size * array.shape[i];
  1260. }
  1261. if (typeof this.rawdata == 'string') {
  1262. array.data = unpickler.unescape(this.rawdata, size);
  1263. if (array.data.length != size) {
  1264. throw new pytorch.Error('Invalid string array data size.');
  1265. }
  1266. }
  1267. else {
  1268. array.data = this.rawdata;
  1269. if (array.data.length != size) {
  1270. // throw new pytorch.Error('Invalid array data size.');
  1271. }
  1272. }
  1273. return array;
  1274. };
  1275. });
  1276. this._registerFunction('__builtin__.bytearray', function(source, encoding /*, errors */) {
  1277. if (encoding === 'latin-1') {
  1278. const array = new Uint8Array(source.length);
  1279. for (let i = 0; i < source.length; i++) {
  1280. array[i] = source.charCodeAt(i);
  1281. }
  1282. return array;
  1283. }
  1284. throw new pytorch.Error("Unsupported bytearray encoding '" + JSON.stringify(encoding) + "'.");
  1285. });
  1286. this._registerFunction('__builtin__.getattr', function(obj, name, defaultValue) {
  1287. if (Object.prototype.hasOwnProperty.call(obj, name)) {
  1288. return obj[name];
  1289. }
  1290. return defaultValue;
  1291. });
  1292. this._registerFunction('__builtin__.set', function(iterable) {
  1293. return iterable ? iterable : [];
  1294. });
  1295. this._registerFunction('__builtin__.slice', function(start, stop , step) {
  1296. return [ start, stop, step ];
  1297. });
  1298. this._registerFunction('collections.Counter', function(/* iterable */) {
  1299. return {};
  1300. });
  1301. this._registerFunction('collections.OrderedDict', function(args) {
  1302. const obj = new Map();
  1303. obj.__setitem__ = function(key, value) {
  1304. obj.set(key, value);
  1305. };
  1306. if (args) {
  1307. for (const arg of args) {
  1308. obj.__setitem__(arg[0], arg[1]);
  1309. }
  1310. }
  1311. return obj;
  1312. });
  1313. this._registerFunction('numpy.core.multiarray.scalar', function(dtype, rawData) {
  1314. let data = rawData;
  1315. if (rawData.constructor !== Uint8Array) {
  1316. data = new Uint8Array(rawData.length);
  1317. for (let i = 0; i < rawData.length; i++) {
  1318. data[i] = rawData.charCodeAt(i);
  1319. }
  1320. }
  1321. const dataView = new DataView(data.buffer, data.byteOffset, data.byteLength);
  1322. switch (dtype.name) {
  1323. case 'float32':
  1324. return dataView.getFloat32(0, true);
  1325. case 'float64':
  1326. return dataView.getFloat64(0, true);
  1327. case 'uint8':
  1328. return dataView.getUint8(0, true);
  1329. case 'int8':
  1330. return dataView.getInt8(0, true);
  1331. case 'int16':
  1332. return dataView.getInt16(0, true);
  1333. case 'int32':
  1334. return dataView.getInt32(0, true);
  1335. case 'int64':
  1336. return dataView.getInt64(0, true);
  1337. }
  1338. throw new pytorch.Error("Unknown scalar type '" + dtype.name + "'.");
  1339. });
  1340. this._registerFunction('_codecs.encode', function(obj /*, econding */) {
  1341. return obj;
  1342. });
  1343. this._registerFunction('collections.defaultdict', function(/* default_factory */) {
  1344. return {};
  1345. });
  1346. this._registerFunction('annotate', function(type, value) {
  1347. return value;
  1348. });
  1349. this._registerFunction('int', function(/* tensor */) {
  1350. return NaN; // TODO
  1351. });
  1352. this._registerFunction('float', function(/* tensor */) {
  1353. return NaN; // TODO
  1354. });
  1355. this._registerFunction('getattr', function(obj, name, defaultValue) {
  1356. if (Object.prototype.hasOwnProperty.call(obj, name)) {
  1357. return obj[name];
  1358. }
  1359. return defaultValue;
  1360. });
  1361. this._registerFunction('unchecked_cast', function(type, value) {
  1362. return value;
  1363. });
  1364. this._registerFunction('ops.prim.data', function(tensor) {
  1365. return tensor;
  1366. });
  1367. this._registerFunction('ops.prim.unchecked_unwrap_optional', function(value) {
  1368. return value;
  1369. });
  1370. this._registerFunction('ops.prim.NumToTensor', function(value) {
  1371. return { __module__: 'torch', __name__: 'Tensor', value: value }; // TODO
  1372. });
  1373. this._registerFunction('ops.prim.min', function(value) {
  1374. return Math.min.apply(null, value);
  1375. });
  1376. this._registerFunction('ops.prim.shape', function(value) {
  1377. return value.size;
  1378. });
  1379. this._registerFunction('ops.quantized.conv_prepack', function(/* weight, bias, stride, padding, dilation, groups */) {
  1380. return { __module__: 'torch', __name__: 'Tensor', __origin__: 'ops.quantized.conv_prepack' }; // TODO
  1381. });
  1382. this._registerFunction('ops.quantized.conv2d_prepack', function(/* weight, bias, stride, padding, dilation, groups */) {
  1383. return { __module__: 'torch', __name__: 'Tensor', __origin__: 'ops.quantized.conv2d_prepack' }; // TODO
  1384. });
  1385. this._registerFunction('ops.quantized.linear_prepack', function(/* weight, bias */) {
  1386. return { __module__: 'torch', __name__: 'Tensor', __origin__: 'ops.quantized.linear_prepack' }; // TODO
  1387. });
  1388. this._registerFunction('ops.prim.RaiseException', function(message) {
  1389. throw new pytorch.Error(message);
  1390. });
  1391. this._registerFunction('range', function(start, stop, step) {
  1392. if (start !== undefined && Number.isInteger(start) && stop === undefined && step === undefined) {
  1393. return Array(start).keys();
  1394. }
  1395. throw new pytorch.Error('Unsupported function range(' + JSON.stringify(start) + ', ' + JSON.stringify(stop) + ', ' + JSON.stringify(step) + ')');
  1396. });
  1397. this._registerFunction('torch._utils._rebuild_tensor', function (storage, storage_offset, size, stride) {
  1398. return {
  1399. __module__: storage.__module__,
  1400. __name__: storage.__name__.replace('Storage', 'Tensor'),
  1401. storage: storage,
  1402. storage_offset: storage_offset,
  1403. size: size,
  1404. stride: stride
  1405. };
  1406. });
  1407. this._registerFunction('torch._utils._rebuild_tensor_v2', function (storage, storage_offset, size, stride, requires_grad, backward_hooks) {
  1408. return {
  1409. __module__: storage.__module__,
  1410. __name__: storage.__name__.replace('Storage', 'Tensor'),
  1411. storage: storage,
  1412. storage_offset: storage_offset,
  1413. size: size,
  1414. stride: stride,
  1415. requires_grad: requires_grad,
  1416. backward_hooks: backward_hooks
  1417. };
  1418. });
  1419. this._registerFunction('torch._utils._rebuild_parameter', function(data, requires_grad, backward_hooks) {
  1420. const obj = self.invoke('torch.nn.parameter.Parameter', [ data, requires_grad ]);
  1421. obj.backward_hooks = backward_hooks;
  1422. return obj;
  1423. });
  1424. this._registerFunction('torch._utils._rebuild_qtensor', function(storage, storage_offset, size, stride, quantizer_params, requires_grad, backward_hooks) {
  1425. return {
  1426. __module__: storage.__module__,
  1427. __name__: storage.__name__.replace('Storage', 'Tensor'),
  1428. storage: storage,
  1429. storage_offset: storage_offset,
  1430. size: size,
  1431. stride: stride,
  1432. quantizer_params: quantizer_params,
  1433. requires_grad:requires_grad,
  1434. backward_hooks: backward_hooks
  1435. };
  1436. });
  1437. this._registerFunction('torch._set_item', function(dict, key, value) {
  1438. dict[key] = value;
  1439. });
  1440. this._registerFunction('torch.__contains__', function(dict, key) {
  1441. return dict[key] !== undefined;
  1442. });
  1443. this._registerFunction('torch.__derive_index', function(index, start, step) {
  1444. return start + index * step;
  1445. });
  1446. this._registerFunction('torch.__is__', function(left, right) {
  1447. if (left === null && right === null) {
  1448. return true;
  1449. }
  1450. if ((left !== null && right === null) || (left === null && right !== null)) {
  1451. return false;
  1452. }
  1453. throw new pytorch.Error("Unknown 'torch.__is__' expression type.");
  1454. });
  1455. this._registerFunction('torch.__isnot__', function(left, right) {
  1456. if (left === null && right === null) {
  1457. return false;
  1458. }
  1459. if ((left !== null && right === null) || (left === null && right !== null)) {
  1460. return true;
  1461. }
  1462. throw new pytorch.Error("Unknown 'torch.__isnot__' expression type.");
  1463. });
  1464. this._registerFunction('torch.__not__', function(value) {
  1465. if (typeof value === 'boolean') {
  1466. return !value;
  1467. }
  1468. throw new pytorch.Error("Unknown 'torch.__not__' expression type.");
  1469. });
  1470. this._registerFunction('torch.__range_length', function(lo, hi, step) {
  1471. if (step === 0) {
  1472. throw new pytorch.Error('range() arg 3 must not be zero');
  1473. }
  1474. if (step > 0 && lo < hi) {
  1475. return 1 + (hi - 1 - lo) / step;
  1476. }
  1477. else if (step < 0 && lo > hi) {
  1478. return 1 + (lo - 1 - hi) / (0 - step);
  1479. }
  1480. return 0;
  1481. });
  1482. this._registerFunction('torch._unwrap_optional', function(value) {
  1483. return value; // TODO
  1484. });
  1485. this._registerFunction('torch.add', function(left, right) {
  1486. if (typeof left === 'number' && typeof right === 'number') {
  1487. return left * right;
  1488. }
  1489. throw new pytorch.Error('Unknown torch.add expression type.');
  1490. });
  1491. this._registerFunction('torch.append', function(tensors, tensor) {
  1492. tensors.push(tensor);
  1493. return tensor;
  1494. });
  1495. this._registerFunction('torch.dict', function(args) {
  1496. if (args) {
  1497. throw new pytorch.Error("'torch.dict' arguments not supported.");
  1498. }
  1499. return {};
  1500. });
  1501. this._registerFunction('torch.dim', function(tensor) {
  1502. if (tensor && tensor.size) {
  1503. return tensor.size.length;
  1504. }
  1505. return 0; // TODO
  1506. });
  1507. this._registerFunction('torch.eq', function(left, right) {
  1508. if (typeof left === 'string' && typeof right === 'string') {
  1509. return left === right;
  1510. }
  1511. if (typeof left === 'number' && typeof right === 'number') {
  1512. return left === right;
  1513. }
  1514. throw new pytorch.Error("Unknown 'torch.eq' expression type.");
  1515. });
  1516. this._registerFunction('torch.floordiv', function(/* left, right */) {
  1517. return undefined;
  1518. });
  1519. this._registerFunction('torch.gt', function(left, right) {
  1520. if (typeof left === 'number' && typeof right === 'number') {
  1521. if (!isNaN(left) && !isNaN(right)) {
  1522. return left > right;
  1523. }
  1524. }
  1525. if (isNaN(left) && !isNaN(right)) {
  1526. return true;
  1527. }
  1528. throw new pytorch.Error("Unknown 'torch.gt' expression type.");
  1529. });
  1530. this._registerFunction('torch.jit._pickle.build_boollist', function(data) {
  1531. return data;
  1532. });
  1533. this._registerFunction('torch.jit._pickle.build_doublelist', function(data) {
  1534. return data;
  1535. });
  1536. this._registerFunction('torch.jit._pickle.build_intlist', function(data) {
  1537. return data;
  1538. });
  1539. this._registerFunction('torch.jit._pickle.build_tensorlist', function(data) {
  1540. return data;
  1541. });
  1542. this._registerFunction('torch.jit._pickle.build_tensor_from_id', function(data) {
  1543. return data;
  1544. });
  1545. this._registerFunction('torch.jit._pickle.restore_type_tag', function(value /*, type_str */) {
  1546. return value;
  1547. });
  1548. this._registerFunction('torch.keys', function(dict) {
  1549. return Object.keys(dict);
  1550. });
  1551. this._registerFunction('torch.len', function(value) {
  1552. if (value) {
  1553. return value.length;
  1554. }
  1555. return NaN;
  1556. });
  1557. this._registerFunction('torch.le', function(left, right) {
  1558. if (typeof left === 'number' && typeof right === 'number') {
  1559. if (isNaN(left) || isNaN(right)) {
  1560. return false;
  1561. }
  1562. return left <= right;
  1563. }
  1564. throw new pytorch.Error("Unknown 'torch.le' expression type.");
  1565. });
  1566. this._registerFunction('torch.list', function(args) {
  1567. return args;
  1568. });
  1569. this._registerFunction('torch.list_with_default', function(size /*, defaults */) {
  1570. return size;
  1571. });
  1572. this._registerFunction('torch.lt', function(left, right) {
  1573. if (typeof left === 'number' && typeof right === 'number') {
  1574. return left < right;
  1575. }
  1576. throw new pytorch.Error("Unknown 'torch.lt' expression type.");
  1577. });
  1578. this._registerFunction('torch.mul', function(left, right) {
  1579. if (typeof left === 'number' && typeof right === 'number') {
  1580. return left * right;
  1581. }
  1582. if (isNaN(left) || isNaN(right)) {
  1583. return NaN;
  1584. }
  1585. throw new pytorch.Error("Unknown 'torch.mul' expression type.");
  1586. });
  1587. this._registerFunction('torch.ne', function(left, right) {
  1588. if (typeof left === 'number' && typeof right === 'number') {
  1589. if (isNaN(left) || isNaN(right)) {
  1590. return false;
  1591. }
  1592. return left !== right;
  1593. }
  1594. if (Array.isArray(left) && Array.isArray(right) && left.length === right.length) {
  1595. return false;
  1596. }
  1597. throw new pytorch.Error("Unknown 'torch.ne' expression type.");
  1598. });
  1599. this._registerFunction('torch.neg', function(value) {
  1600. if (typeof value === 'number') {
  1601. return -value;
  1602. }
  1603. throw new pytorch.Error("Unknown 'torch.neg' expression type.");
  1604. });
  1605. this._registerFunction('torch.q_scale', function(/* tensor */) {
  1606. return -1; // TODO
  1607. });
  1608. this._registerFunction('torch.t', function(tensor) {
  1609. return tensor;
  1610. });
  1611. this._registerFunction('torch.size', function(tensor, dim) {
  1612. if (tensor && Array.isArray(tensor.size)) {
  1613. if (dim === undefined) {
  1614. return tensor.size;
  1615. }
  1616. if (Number.isInteger(dim)) {
  1617. if (dim >= 0 && dim < tensor.size.length) {
  1618. return tensor.size[dim];
  1619. }
  1620. if (dim < 0 && -dim < tensor.size.length) {
  1621. return tensor.size[tensor.size.length + dim];
  1622. }
  1623. }
  1624. throw new pytorch.Error('Dimension out of range (expected to be in range of ' + JSON.stringify(tensor.size) + ', but got ' + JSON.stringify(dim) + ').');
  1625. }
  1626. return NaN;
  1627. });
  1628. this._registerFunction('torch.slice', function(l, start, end, step) {
  1629. if (step !== 1) {
  1630. throw new pytorch.Error('Slicing only supports step=1');
  1631. }
  1632. start = Math.max(0, start);
  1633. end = Math.min(l.length, end);
  1634. return l.slice(start, end);
  1635. });
  1636. this._registerFunction('torch.sub', function(left, right) {
  1637. if (typeof left === 'number' && typeof right === 'number') {
  1638. return left * right;
  1639. }
  1640. throw new pytorch.Error("Unknown 'torch.sub' expression type.");
  1641. });
  1642. this._registerFunction('torch.values', function(dict) {
  1643. return Object.keys(dict).map((key) => dict[key]);
  1644. });
  1645. this._registerFunction('torch.warn', function() {
  1646. });
  1647. this._registerFunction('uninitialized', function(type) {
  1648. if (type && type.__module__ === 'typing' && type.__name__ === 'Tuple') {
  1649. return [];
  1650. }
  1651. if (type && type.__module__ === 'typing' && type.__name__ === 'List') {
  1652. return [];
  1653. }
  1654. if (type && type.__module__ === 'typing' && type.__name__ === 'Dict') {
  1655. return {};
  1656. }
  1657. if (type && type.__module__ === 'torch' && type.__name__ === 'Tensor') {
  1658. return { __module__: type.__module__, __name__: type.__name__ };
  1659. }
  1660. throw new pytorch.Error("Unsupported uninitialized argument '" + JSON.stringify(type) + "'.");
  1661. });
  1662. }
  1663. get context() {
  1664. return this._context;
  1665. }
  1666. parse(file) {
  1667. const data = this._sources[file];
  1668. if (data) {
  1669. const code = this._utf8Decoder.decode(data);
  1670. const reader = new this._python.Parser(code, file);
  1671. const program = reader.parse();
  1672. if (!program) {
  1673. throw new pytorch.Error("Module '" + file + "' parse error.");
  1674. }
  1675. return program;
  1676. }
  1677. return null;
  1678. }
  1679. package(name, file, raw) {
  1680. if (this._python && !this._packages.has(name)) {
  1681. file = file || 'code/' + name.split('.').join('/') + '.py';
  1682. const program = this.parse(file);
  1683. if (program) {
  1684. let globals = this._context.getx(name);
  1685. if (globals === undefined) {
  1686. globals = {};
  1687. this._context.setx(name, globals);
  1688. }
  1689. globals.__class__ = this._context.scope.builtins.module;
  1690. globals.__name__ = name;
  1691. globals.__file__ = file;
  1692. this._packages.set(name, globals);
  1693. const context = this._context.push(globals);
  1694. this._block(program.body, context);
  1695. if (raw) {
  1696. return program;
  1697. }
  1698. }
  1699. }
  1700. return this._packages.get(name);
  1701. }
  1702. type(name) {
  1703. const type = this._context.getx(name);
  1704. if (type !== undefined) {
  1705. return type;
  1706. }
  1707. const parts = name.split('.');
  1708. const className = parts.pop();
  1709. const moduleName = parts.join('.');
  1710. const module = this.package(moduleName);
  1711. if (module) {
  1712. return module[className];
  1713. }
  1714. return null;
  1715. }
  1716. invoke(name, args) {
  1717. const target = this.type(name);
  1718. if (target) {
  1719. if (target.__class__ === this._context.scope.builtins.type) {
  1720. const obj = {};
  1721. obj.__proto__ = target;
  1722. if (obj.__init__ && typeof obj.__init__ === 'function') {
  1723. obj.__init__.apply(obj, args);
  1724. }
  1725. return obj;
  1726. }
  1727. else if (target.__class__ === this._context.scope.builtins.function) {
  1728. if (target.__call__) {
  1729. return target.__call__(args);
  1730. // throw new pytorch.Error('Unexpected function __call__.');
  1731. }
  1732. else {
  1733. return target.apply(null, args);
  1734. }
  1735. }
  1736. }
  1737. this._raiseUnkownName(name);
  1738. const typeParts = name.split('.');
  1739. const typeName = typeParts.pop();
  1740. const typeModule = typeParts.join('.');
  1741. return {
  1742. __module__: typeModule,
  1743. __name__: typeName
  1744. };
  1745. }
  1746. call(target, name, args, context) {
  1747. const callTarget = this._target(target, context);
  1748. const callArguments = args.map((argument) => this.expression(argument, context));
  1749. if (!callTarget || (name !== null && !callTarget[name])) {
  1750. const targetName = pytorch.Utility.target(target) + '.' + name;
  1751. if (this.type(targetName)) {
  1752. return this.invoke(targetName, callArguments);
  1753. }
  1754. throw new pytorch.Error("Unsupported function '" + targetName + "'.");
  1755. }
  1756. const func = name ? callTarget[name] : callTarget;
  1757. if (func.__class__ === this._context.scope.builtins.type) {
  1758. const obj = {};
  1759. obj.__proto__ = func;
  1760. if (obj.__init__ && typeof obj.__init__ === 'function') {
  1761. obj.__init__.apply(obj, args);
  1762. }
  1763. return obj;
  1764. }
  1765. if (func.__class__ === this._context.scope.builtins.function) {
  1766. if (func.__call__) {
  1767. return func.__call__(callArguments);
  1768. }
  1769. }
  1770. if (func.__class__ === this._context.scope.builtins.method) {
  1771. if (func.__call__) {
  1772. return func.__call__([ callTarget ].concat(callArguments));
  1773. }
  1774. }
  1775. if (typeof func === 'function') {
  1776. return func.apply(callTarget, callArguments);
  1777. }
  1778. throw new pytorch.Error("Unsupported call expression.");
  1779. }
  1780. apply(method, args, context) {
  1781. const locals = Array.prototype.slice.call(args);
  1782. context = context.push();
  1783. for (const parameter of method.parameters) {
  1784. context.set(parameter.name, locals.shift());
  1785. }
  1786. return this._block(method.body.statements, context);
  1787. }
    // Executes a list of parsed statements in `context`.
    // Statements are consumed from a work queue. 'if' and 'for' splice their
    // bodies onto the front of that queue instead of recursing, so a 'return'
    // inside such a body ends this whole call.
    // Returns the value of the first 'return' reached, or undefined when the
    // queue drains. Unknown statement types throw pytorch.Error.
    _block(statements, context) {
        // Work on a copy so shift()/reassignment never touch the caller's array.
        statements = Array.prototype.slice.call(statements);
        while (statements.length > 0) {
            const statement = statements.shift();
            switch (statement.type) {
                case 'pass': {
                    break;
                }
                case 'return': {
                    return this.expression(statement.expression, context);
                }
                case 'def': {
                    // Build a callable object. It is a 'method' inside a class body,
                    // a 'function' at module level; any other enclosing scope is rejected.
                    const module = context.get('__name__');
                    const self = this;
                    const parent = context.get('__class__');
                    let type = null;
                    if (parent === this._context.scope.builtins.type) {
                        type = this._context.scope.builtins.method;
                    }
                    else if (parent === this._context.scope.builtins.module) {
                        type = this._context.scope.builtins.function;
                    }
                    else {
                        throw new pytorch.Error('Invalid function scope.');
                    }
                    const func = {
                        __class__: type,
                        __globals__: context,
                        __module__: module,
                        __name__: statement.name,
                        __code__: statement,
                        // Runs the body via apply() against the defining context.
                        __call__: function(args) {
                            return self.apply(this.__code__, args, this.__globals__);
                        }
                    };
                    context.set(statement.name, func);
                    break;
                }
                case 'class': {
                    // The class object is bound by name, then pushed as the scope its
                    // body runs in. Presumably definitions in the body land on it —
                    // this depends on context.push/set, which are not shown here.
                    const scope = {
                        __class__:this._context.scope.builtins.type,
                        __module__: context.get('__name__'),
                        __name__: statement.name,
                    };
                    context.set(statement.name, scope);
                    context = context.push(scope);
                    this._block(statement.body.statements, context);
                    context = context.pop();
                    break;
                }
                case 'var': {
                    // Declaration only: the name is bound to undefined.
                    context.set(statement.name, undefined);
                    break;
                }
                case '=': {
                    this.expression(statement, context);
                    break;
                }
                case 'if': {
                    // Any truthy condition runs 'then'; only a strict `false` runs 'else'.
                    // Other falsy results (null, undefined, 0, NaN, '') are rejected.
                    // NOTE(review): assumes the parser always supplies an `else` block — confirm.
                    const condition = this.expression(statement.condition, context);
                    if (condition === true || condition) {
                        statements = statement.then.statements.concat(statements);
                        break;
                    }
                    else if (condition === false) {
                        statements = statement.else.statements.concat(statements);
                        break;
                    }
                    throw new pytorch.Error("Unknown condition.");
                }
                case 'for': {
                    // Only `for <id> in <expr>` is supported, and it is unrolled eagerly.
                    // For each value the iterable yields, an assignment `<id> = value`
                    // (as a 'number' literal node) plus a copy of the body is queued.
                    if (statement.target.length == 1 &&
                        statement.variable.length === 1 && statement.variable[0].type === 'id') {
                        const range = this.expression(statement.target[0], context);
                        const variable = statement.variable[0];
                        let loop = [];
                        for (const value of range) {
                            loop.push({ type: '=', target: variable, expression: { type: 'number', value: value }});
                            loop = loop.concat(statement.body.statements);
                        }
                        statements = loop.concat(statements);
                        break;
                    }
                    throw new pytorch.Error("Unsupported 'for' statement.");
                }
                case 'call': {
                    this.expression(statement, context);
                    break;
                }
                case 'import': {
                    // Every listed module is loaded; it is bound only when an `as` alias is given.
                    for (const module of statement.modules) {
                        const moduleName = pytorch.Utility.target(module.name);
                        const globals = this.package(moduleName);
                        if (module.as) {
                            context.set(module.as, globals);
                        }
                    }
                    break;
                }
                default: {
                    throw new pytorch.Error("Unknown statement '" + statement.type + "'.");
                }
            }
        }
    }
    // Evaluates one parsed expression node in `context` and returns its value.
    // Assignments ('=') return undefined. Node types or assignment forms not
    // handled below fall through to the final throw (pytorch.Error).
    expression(expression, context) {
        // Looked up once up front; returned for the 'self' identifier.
        const self = context.getx('self');
        switch (expression.type) {
            case '=': {
                const target = expression.target;
                // name = value
                if (target.type === 'id') {
                    context.set(target.value, this.expression(expression.expression, context));
                    return;
                }
                // name[index] = value. `__annotations__` is created on first use;
                // any other container must already exist.
                else if (target.type === '[]') {
                    if (target.target.type === 'id' &&
                        target.arguments.type === 'list' &&
                        target.arguments.value.length === 1) {
                        const index = this.expression(target.arguments.value[0], context);
                        if (target.target.value === '__annotations__') {
                            context.set(target.target.value, context.get(target.target.value) || {});
                        }
                        context.get(target.target.value)[index] = this.expression(expression.expression, context);
                        return;
                    }
                }
                // obj.member = value
                else if (target.type === '.' &&
                    target.member.type === 'id') {
                    this.expression(target.target, context)[target.member.value] = this.expression(expression.expression, context);
                    return;
                }
                // a, b = value: only when the lengths match and every target is a plain name.
                else if (target.type === 'tuple') {
                    const value = this.expression(expression.expression, context);
                    if (target.value.length == value.length && target.value.every((item) => item.type === 'id')) {
                        for (let i = 0; i < value.length; i++) {
                            context.set(target.value[i].value, value[i]);
                        }
                        return;
                    }
                }
                break;
            }
            case 'list': {
                return expression.value.map((item) => this.expression(item, context));
            }
            case 'string': {
                // Drops the first and last characters (presumably the quote delimiters).
                return expression.value.substring(1, expression.value.length - 1);
            }
            case 'number': {
                return Number(expression.value);
            }
            case '[]': {
                // Fast path: name[index] where `name` resolves to a truthy value in context.
                if (expression.target.type === 'id' &&
                    expression.arguments.type === 'list' &&
                    expression.arguments.value.length === 1) {
                    if (context.get(expression.target.value)) {
                        const index = this.expression(expression.arguments.value[0], context);
                        return context.get(expression.target.value)[index];
                    }
                }
                const target = this.expression(expression.target, context);
                // Generic typing forms (e.g. List[int]): a shallow copy with evaluated __args__.
                if (target && expression.arguments.type === 'list' &&
                    (target.__class__ === this.context.scope.typing._VariadicGenericAlias ||
                    target.__class__ === this.context.scope.typing._GenericAlias ||
                    target.__class__ === this.context.scope.typing._SpecialForm)) {
                    const type = Object.assign({}, target);
                    type.__args__ = expression.arguments.value.map((arg) => this.expression(arg, context));
                    return type;
                }
                if (expression.arguments.type === 'list' && expression.arguments.value.length === 1) {
                    const index = this.expression(expression.arguments.value[0], context);
                    return target[index];
                }
                break;
            }
            case '.': {
                // Member access: the receiver is resolved via _target, so dotted
                // module paths are loaded as packages when needed.
                if (expression.member.type == 'id') {
                    const target = this._target(expression.target, context);
                    return target[expression.member.value];
                }
                throw new pytorch.Error("Unsupported field expression.");
            }
            case 'call': {
                // annotate(T, v) and unchecked_cast(T, v) evaluate to v; T is not evaluated.
                if (expression.target.type === 'id' && expression.target.value === 'annotate' && expression.arguments.length === 2) {
                    return this.expression(expression.arguments[1], context);
                }
                if (expression.target.type === 'id' && expression.target.value === 'unchecked_cast' && expression.arguments.length === 2) {
                    return this.expression(expression.arguments[1], context);
                }
                if (expression.target.type === '.') {
                    return this.call(expression.target.target, expression.target.member.value, expression.arguments, context);
                }
                return this.call(expression.target, null, expression.arguments, context);
            }
            case 'id': {
                switch (expression.value) {
                    case 'self': return self;
                    case 'None': return null;
                    case 'True': return true;
                    case 'False': return false;
                }
                // Names from the builtins, typing and torch scopes take precedence
                // over context, but only when they are types or typing forms.
                const type =
                    this._context.scope.builtins[expression.value] ||
                    this._context.scope.typing[expression.value] ||
                    this._context.scope.torch[expression.value];
                if (type &&
                    (type.__class__ === this._context.scope.builtins.type ||
                    type.__class__ === this._context.scope.typing._VariadicGenericAlias ||
                    type.__class__ === this._context.scope.typing._GenericAlias ||
                    type.__class__ === this._context.scope.typing._SpecialForm)) {
                    return type;
                }
                return context.get(expression.value);
            }
            case 'tuple': {
                return expression.value.map((expression) => this.expression(expression, context));
            }
        }
        throw new pytorch.Error("Unknown expression '" + expression.type + "'.");
    }
  2008. _target(expression, context) {
  2009. let current = expression;
  2010. let packageName = '';
  2011. for (;;) {
  2012. if (current.type === '.' && current.member && current.member.type === 'id') {
  2013. packageName = '.' + current.member.value + packageName;
  2014. current = current.target;
  2015. }
  2016. else if (current.type === 'id' && current.value !== 'self' && current.value !== 'CONSTANTS') {
  2017. packageName = current.value + packageName;
  2018. break;
  2019. }
  2020. else {
  2021. packageName = null;
  2022. break;
  2023. }
  2024. }
  2025. if (packageName) {
  2026. let target = context.getx(packageName);
  2027. if (!target) {
  2028. target = this.package(packageName);
  2029. if (!target) {
  2030. throw new pytorch.Error("Failed to resolve module '" + packageName + "'.");
  2031. }
  2032. }
  2033. return target;
  2034. }
  2035. return this.expression(expression, context);
  2036. }
  2037. _registerFunction(name, callback) {
  2038. if (this._context.getx(name)) {
  2039. throw new pytorch.Error("Function '" + name + "' is already registered.");
  2040. }
  2041. const parts = name.split('.');
  2042. callback.__class__ = this._context.scope.builtins.function;
  2043. callback.__name__ = parts.pop();
  2044. callback.__module__ = parts.join('.');
  2045. this._context.setx(name, callback);
  2046. }
  2047. _registerConstructor(name, callback) {
  2048. if (this._context.getx(name)) {
  2049. throw new pytorch.Error("Constructor '" + name + "' is already registered.");
  2050. }
  2051. const parts = name.split('.');
  2052. const typeName = parts.pop();
  2053. const typeModule = parts.join('.');
  2054. const type = {
  2055. __class__: this._context.scope.builtins.type,
  2056. __name__: typeName,
  2057. __module__: typeModule,
  2058. __init__: function() {
  2059. callback.apply(this, arguments);
  2060. }
  2061. };
  2062. this._context.setx(name, type);
  2063. }
  2064. _raiseUnkownName(name) {
  2065. if (name && !this._unknownNameMap.has(name)) {
  2066. this._unknownNameMap.add(name);
  2067. if (this._knownPackageMap.has(name.split('.').shift())) {
  2068. this._exceptionCallback(new pytorch.Error("Unknown function '" + name + "'."), false);
  2069. }
  2070. }
  2071. }
  2072. };
  2073. pytorch.Execution.Context = class {
  2074. constructor(parent, scope) {
  2075. this._parent = parent || null;
  2076. this._scope = scope || {};
  2077. }
  2078. push(scope) {
  2079. return new pytorch.Execution.Context(this, scope);
  2080. }
  2081. pop() {
  2082. return this._parent;
  2083. }
  2084. get scope() {
  2085. return this._scope;
  2086. }
  2087. set(name, value) {
  2088. this._scope[name] = value;
  2089. }
  2090. get(name) {
  2091. if (name in this._scope) {
  2092. return this._scope[name];
  2093. }
  2094. if (this._parent) {
  2095. return this._parent.get(name);
  2096. }
  2097. return undefined;
  2098. }
  2099. setx(name, value) {
  2100. const parts = name.split('.');
  2101. if (parts.length == 1) {
  2102. this.set(parts[0], value);
  2103. }
  2104. else {
  2105. let parent = this.get(parts[0]);
  2106. if (!parent) {
  2107. parent = {};
  2108. this.set(parts[0], parent);
  2109. }
  2110. parts.shift();
  2111. while (parts.length > 1) {
  2112. const part = parts.shift();
  2113. parent[part] = parent[part] || {};
  2114. parent = parent[part];
  2115. }
  2116. parent[parts[0]] = value;
  2117. }
  2118. }
  2119. getx(name) {
  2120. const parts = name.split('.');
  2121. let value = this.get(parts[0]);
  2122. if (value) {
  2123. parts.shift();
  2124. while (parts.length > 0 && value[parts[0]]) {
  2125. value = value[parts[0]];
  2126. parts.shift();
  2127. }
  2128. if (parts.length === 0) {
  2129. return value;
  2130. }
  2131. }
  2132. return undefined;
  2133. }
  2134. };
  2135. pytorch.Container = class {
  2136. static open(context, metadata, pickle, python, exception) {
  2137. if (context.entries('zip').some((entry) => entry.name === 'model.json' || entry.name === 'data.pkl' || entry.name.endsWith('/model.json') || entry.name.endsWith('/data.pkl'))) {
  2138. return new pytorch.Container.Zip(context.entries('zip'), metadata, pickle, python, exception);
  2139. }
  2140. const buffer = context.buffer;
  2141. const signature = [ 0x8a, 0x0a, 0x6c, 0xfc, 0x9c, 0x46, 0xf9, 0x20, 0x6a, 0xa8, 0x50, 0x19 ];
  2142. if (buffer && buffer.length > 14 && buffer[0] == 0x80 && buffer[1] < 0x10 && signature.every((v, i) => v == buffer[i + 2])) {
  2143. return new pytorch.Container.Pickle(buffer, pickle, exception);
  2144. }
  2145. if (context.entries('tar').some((entry) => entry.name == 'pickle')) {
  2146. return new pytorch.Container.Tar(context.entries('tar'), pickle, exception);
  2147. }
  2148. return null;
  2149. }
  2150. };
  2151. pytorch.Container.Tar = class {
  2152. constructor(entries, pickle, exceptionCallback) {
  2153. this._entries = entries;
  2154. this._pickle = pickle;
  2155. this._exceptionCallack = exceptionCallback;
  2156. }
  2157. get format() {
  2158. return 'PyTorch v0.1.1';
  2159. }
  2160. get data() {
  2161. this._unpickle();
  2162. return this._data;
  2163. }
  2164. get state() {
  2165. this._unpickle();
  2166. return this._state;
  2167. }
  2168. get littleEndian() {
  2169. this._unpickle();
  2170. return this._littleEndian;
  2171. }
  2172. _unpickle() {
  2173. if (!this._entries) {
  2174. return;
  2175. }
  2176. this._data = null;
  2177. this._state = null;
  2178. this._littleEndian = true;
  2179. const execution = new pytorch.Execution(null, [], this._exceptionCallback);
  2180. const entries = {};
  2181. for (const entry of this._entries) {
  2182. switch (entry.name) {
  2183. case 'sys_info': entries.sys_info = entry.data; break;
  2184. case 'pickle': entries.pickle = entry.data; break;
  2185. case 'storages': entries.storages = entry.data; break;
  2186. case 'tensors': entries.tensors = entry.data; break;
  2187. }
  2188. }
  2189. this._exceptionCallback = null;
  2190. this._entries = null;
  2191. if (entries.sys_info) {
  2192. const unpickler = new this._pickle.Unpickler(entries.sys_info);
  2193. const sys_info = unpickler.load((name, args) => execution.invoke(name, args));
  2194. if (sys_info.protocol_version != 1000) {
  2195. throw new pytorch.Error("Unsupported protocol version '" + sys_info.protocol_version + "'.");
  2196. }
  2197. if (sys_info.type_sizes &&
  2198. ((sys_info.type_sizes.int && sys_info.type_sizes.int != 4) ||
  2199. (sys_info.type_sizes.long && sys_info.type_sizes.long != 4) ||
  2200. (sys_info.type_sizes.short && sys_info.type_sizes.short != 2))) {
  2201. throw new pytorch.Error('Unsupported type sizes.');
  2202. }
  2203. this._littleEndian = sys_info.little_endian;
  2204. }
  2205. const deserialized_objects = {};
  2206. if (entries.storages) {
  2207. const unpickler = new this._pickle.Unpickler(entries.storages);
  2208. const num_storages = unpickler.load((name, args) => execution.invoke(name, args));
  2209. for (let i = 0; i < num_storages; i++) {
  2210. const storage_args = unpickler.load();
  2211. const storage_key = storage_args[0];
  2212. const storage_type = storage_args[2];
  2213. const size = pytorch.Utility.readInt64(unpickler.read(8));
  2214. const storage = execution.invoke(storage_type, [ size ]);
  2215. storage.data = unpickler.read(storage.dataTypeSize * storage.size);
  2216. deserialized_objects[storage_key] = storage;
  2217. }
  2218. /*
  2219. let storage_views = unpickler.load();
  2220. for target_cdata, root_cdata, offset, size in storage_views:
  2221. root = deserialized_objects[root_cdata]
  2222. deserialized_objects[target_cdata] = root[offset:offset + size]
  2223. */
  2224. }
  2225. if (entries.tensors) {
  2226. const unpickler = new this._pickle.Unpickler(entries.tensors);
  2227. const num_tensors = unpickler.load((name, args) => execution.invoke(name, args));
  2228. for (let j = 0; j < num_tensors; j++) {
  2229. const tensor_args = unpickler.load();
  2230. const tensor_key = tensor_args[0];
  2231. const storage_id = tensor_args[1];
  2232. const storage = deserialized_objects[storage_id];
  2233. const ndim = pytorch.Utility.readInt32(unpickler.read(4));
  2234. unpickler.read(4);
  2235. const shape = [];
  2236. for (let k = 0; k < ndim; k++) {
  2237. shape.push(pytorch.Utility.readInt64(unpickler.read(8)));
  2238. }
  2239. const stride = [];
  2240. for (let l = 0; l < ndim; l++) {
  2241. stride.push(pytorch.Utility.readInt64(unpickler.read(8)));
  2242. }
  2243. const storage_offset = pytorch.Utility.readInt64(unpickler.read(8));
  2244. const tensor_type_name = storage.__name__.replace('Storage', 'Tensor');
  2245. const tensor = execution.invoke(storage.__module__ + '.' + tensor_type_name, []);
  2246. tensor.__setstate__([ storage, storage_offset, shape, stride ]);
  2247. deserialized_objects[tensor_key] = tensor;
  2248. }
  2249. }
  2250. if (entries.pickle) {
  2251. const unpickler = new this._pickle.Unpickler(entries.pickle);
  2252. const persistent_load = (saved_id) => {
  2253. return deserialized_objects[saved_id];
  2254. };
  2255. let obj = unpickler.load((name, args) => execution.invoke(name, args), persistent_load);
  2256. if (obj) {
  2257. if (!(obj instanceof Map)) {
  2258. const map = new Map();
  2259. for (const key of Object.keys(obj)) {
  2260. map.set(key, obj[key]);
  2261. }
  2262. obj = map;
  2263. }
  2264. this._state = [];
  2265. const state_map = {};
  2266. if (obj instanceof Map) {
  2267. for (const item of obj) {
  2268. const key = item[0];
  2269. const value = item[1];
  2270. if (!key || !value) {
  2271. this._state = null;
  2272. break;
  2273. }
  2274. const state = {};
  2275. state.id = key;
  2276. state.value = null;
  2277. if (value && value.__module__ === 'torch.nn.parameter' && value.__name__ === 'Parameter') {
  2278. state.value = value[0];
  2279. }
  2280. else if (pytorch.Utility.isTensor(value)) {
  2281. state.value = value;
  2282. }
  2283. if (!state.value) {
  2284. this._state = null;
  2285. break;
  2286. }
  2287. const split = state.id.split('.');
  2288. if (split.length < 2) {
  2289. this._state = null;
  2290. break;
  2291. }
  2292. state.name = split.pop();
  2293. const state_group_name = split.join('.');
  2294. let state_group = state_map[state_group_name];
  2295. if (!state_group) {
  2296. state_group = {};
  2297. state_group.name = state_group_name;
  2298. state_group.states = [];
  2299. state_map[state_group_name] = state_group;
  2300. this._state.push(state_group);
  2301. }
  2302. state_group.states.push({ name: state.name, arguments: [ state ] });
  2303. }
  2304. }
  2305. }
  2306. }
  2307. }
  2308. };
  2309. pytorch.Container.Pickle = class {
  2310. constructor(buffer, pickle, exception) {
  2311. this._buffer = buffer;
  2312. this._pickle = pickle;
  2313. this._exceptionCallback = exception;
  2314. }
  2315. get format() {
  2316. return 'PyTorch v0.1.10';
  2317. }
  2318. get data() {
  2319. this._unpickle();
  2320. return this._data;
  2321. }
  2322. get state() {
  2323. this._unpickle();
  2324. return this._state;
  2325. }
  2326. get littleEndian() {
  2327. this._unpickle();
  2328. return this._littleEndian;
  2329. }
  2330. _unpickle() {
  2331. if (!this._buffer) {
  2332. return;
  2333. }
  2334. const execution = new pytorch.Execution(null, [], this._exceptionCallback);
  2335. const unpickler = new this._pickle.Unpickler(this._buffer);
  2336. this._buffer = null;
  2337. this._pickle = null;
  2338. this._exceptionCallback = null;
  2339. unpickler.load(); // magic_number
  2340. const protocol_version = unpickler.load();
  2341. if (protocol_version != 1001) {
  2342. throw new pytorch.Error("Unsupported protocol version '" + protocol_version + "'.");
  2343. }
  2344. const sys_info = unpickler.load();
  2345. if (sys_info.protocol_version != 1001) {
  2346. throw new pytorch.Error("Unsupported protocol version '" + sys_info.protocol_version + "'.");
  2347. }
  2348. if (sys_info.type_sizes &&
  2349. ((sys_info.type_sizes.int && sys_info.type_sizes.int != 4) ||
  2350. (sys_info.type_sizes.long && sys_info.type_sizes.long != 4) ||
  2351. (sys_info.type_sizes.short && sys_info.type_sizes.short != 2))) {
  2352. throw new pytorch.Error('Unsupported type sizes.');
  2353. }
  2354. this._littleEndian = sys_info.little_endian;
  2355. const module_source_map = new Map();
  2356. const deserialized_objects = new Map();
  2357. const persistent_load = (saved_id) => {
  2358. const typename = saved_id.shift();
  2359. const data = saved_id;
  2360. switch (typename) {
  2361. case 'module': {
  2362. const module = data[0];
  2363. const source = data[2];
  2364. module_source_map.set(module, source);
  2365. return data[0];
  2366. }
  2367. case 'storage': {
  2368. const data_type = data.shift();
  2369. const root_key = data.shift();
  2370. data.shift(); // location
  2371. const size = data.shift();
  2372. const view_metadata = data.shift();
  2373. if (!deserialized_objects.has(root_key)) {
  2374. const storage = execution.invoke(data_type, [ size ]);
  2375. deserialized_objects.set(root_key, storage);
  2376. }
  2377. if (view_metadata) {
  2378. const view_key = view_metadata.shift();
  2379. view_metadata.shift(); // view_offset
  2380. view_metadata.shift(); // view_size
  2381. if (!deserialized_objects.has(view_key)) {
  2382. const view = null; // storage.slice(view_offset, view_offset + view_size);
  2383. deserialized_objects.set(view_key, view);
  2384. }
  2385. return deserialized_objects.get(view_key);
  2386. }
  2387. return deserialized_objects.get(root_key);
  2388. }
  2389. }
  2390. throw new pytorch.Error("Unknown persistent load type '" + typename + "'.");
  2391. };
  2392. const data = unpickler.load((name, args) => execution.invoke(name, args), persistent_load);
  2393. if (!data) {
  2394. throw new pytorch.Error('File format is not PyTorch.');
  2395. }
  2396. const deserialized_storage_keys = unpickler.load();
  2397. for (const deserialized_storage_key of deserialized_storage_keys) {
  2398. const storage = deserialized_objects.get(deserialized_storage_key);
  2399. const size = pytorch.Utility.readInt64(unpickler.read(8));
  2400. if (size != storage.size) {
  2401. throw new pytorch.Error('Storage size mismatch.');
  2402. }
  2403. storage.data = unpickler.read(storage.dataTypeSize * storage.size);
  2404. }
  2405. this._data = pytorch.Utility.findRootModule(data);
  2406. if (!this._data) {
  2407. this._state = pytorch.Utility._findStateDict(data);
  2408. }
  2409. if (!this._data && !this._state && data !== 'None') {
  2410. throw new pytorch.Error('File does not contain root module or state dictionary.');
  2411. }
  2412. }
  2413. };
  2414. pytorch.Container.Zip = class {
  2415. constructor(entries, metadata, pickle, python, exceptionCallback) {
  2416. this._entries = entries;
  2417. this._metadata = metadata;
  2418. this._pickle = pickle;
  2419. this._python = python;
  2420. this._exceptionCallback = exceptionCallback;
  2421. // https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/docs/serialization.md
  2422. const entry = this._entries.find((entry) => entry.name == 'model.json' || entry.name == 'data.pkl' || entry.name.endsWith('/model.json') || entry.name.endsWith('/data.pkl'));
  2423. if (!entry) {
  2424. throw new pytorch.Error("PyTorch Zip container does not contain 'data.pkl' or 'model.json'.");
  2425. }
  2426. const lastIndex = entry.name.lastIndexOf('/');
  2427. this._prefix = lastIndex === -1 ? '' : entry.name.substring(0, lastIndex + 1);
  2428. this._utf8Decoder = new TextDecoder('utf-8');
  2429. }
  2430. get format() {
  2431. if (this._format === undefined) {
  2432. if (this._entry('model.json')) {
  2433. this._format = this._entry('attributes.pkl') ? 'TorchScript v1.1' : 'TorchScript v1.0';
  2434. }
  2435. else if (this._entry('data.pkl')) {
  2436. // kProducedFileFormatVersion in https://github.com/pytorch/pytorch/blob/master/caffe2/serialize/inline_container.h
  2437. const versionEntry = this._entry('version');
  2438. const versionNumber = versionEntry ? this._utf8Decoder.decode(versionEntry.data).split('\n').shift() : '';
  2439. const versionTable = { '1': 'v1.3', '2': 'v1.4', '3': 'v1.6', '4': 'v1.7' };
  2440. const version = versionTable[versionNumber];
  2441. if (!version) {
  2442. this._exceptionCallback(new pytorch.Error("Unsupported PyTorch Zip version '" + versionNumber + "'."));
  2443. }
  2444. this._format = (this._entry('constants.pkl') ? 'TorchScript' : 'PyTorch') + ' ' + (version || 'v-' + versionNumber.toString() );
  2445. }
  2446. }
  2447. return this._format;
  2448. }
  2449. get producer() {
  2450. return this.data ? this._producer : '';
  2451. }
  2452. get name() {
  2453. return this._name;
  2454. }
  2455. get data() {
  2456. this._load();
  2457. return this._data;
  2458. }
  2459. get state() {
  2460. this._load();
  2461. return this._state;
  2462. }
  2463. get constants() {
  2464. if (this._constants === undefined) {
  2465. this._constants = [];
  2466. const entry = this._entry('constants.pkl');
  2467. if (entry && entry.data) {
  2468. this._constants = this._unpickle(entry.data, this._storage('constants'));
  2469. }
  2470. }
  2471. return this._constants;
  2472. }
  2473. get execution() {
  2474. if (this._execution === undefined) {
  2475. this._types = new Map(); // TODO
  2476. const sources = {};
  2477. for (const entry of this._entries) {
  2478. if (entry.name.startsWith(this._prefix + 'code')) {
  2479. const file = entry.name.substring(this._prefix.length);
  2480. if (sources[file]) {
  2481. throw new pytorch.Error("Duplicate source file '" + file + "'.");
  2482. }
  2483. sources[file] = entry.data;
  2484. }
  2485. }
  2486. this._execution = new pytorch.Container.Zip.Execution(this._python, sources, this._exceptionCallback, this._metadata);
  2487. const constants = {};
  2488. for (let i = 0; i < this.constants.length; i++) {
  2489. constants['c' + i.toString()] = this.constants[i];
  2490. }
  2491. this._execution.context.set('CONSTANTS', constants);
  2492. }
  2493. return this._execution;
  2494. }
  2495. _entry(name) {
  2496. return this._entries.find((entry) => entry.name == this._prefix + name);
  2497. }
  2498. _load() {
  2499. if (this._data === undefined) {
  2500. this._data = null;
  2501. const dataEntry = this._entry('data.pkl');
  2502. if (dataEntry && dataEntry.data) {
  2503. this._data = this._unpickle(dataEntry.data, this._storage('data'));
  2504. }
  2505. else {
  2506. const modelEntry = this._entry('model.json');
  2507. if (modelEntry) {
  2508. const model = JSON.parse(this._utf8Decoder.decode(modelEntry.data));
  2509. this._producer = model.producerName + (model.producerVersion ? ' v' + model.producerVersion : '');
  2510. this._data = model.mainModule || {};
  2511. this._name = this._data.name || '';
  2512. if (this._data.torchscriptArena) {
  2513. this._torchscriptArena = this._data.torchscriptArena.key;
  2514. }
  2515. const queue = [ this._data ];
  2516. const entries = new Map();
  2517. for (const entry of this._entries) {
  2518. entries.set(entry.name, entry.data);
  2519. }
  2520. const tensorTypeMap = new Map([
  2521. [ 'FLOAT', 'Float' ],
  2522. [ 'FLOAT16', 'Half' ],
  2523. [ 'DOUBLE', 'Double' ],
  2524. [ 'INT8', 'Char' ],
  2525. [ 'INT32', 'Int' ],
  2526. [ 'INT64', 'Long' ]
  2527. ]);
  2528. this._constants = model.tensors || [];
  2529. for (const tensor of this._constants) {
  2530. const key = this._prefix + tensor.data.key;
  2531. if (!tensorTypeMap.has(tensor.dataType)) {
  2532. throw new pytorch.Error("Unknown tensor data type '" + tensor.dataType + "'.");
  2533. }
  2534. const type = tensorTypeMap.get(tensor.dataType);
  2535. tensor.__module__ = 'torch';
  2536. tensor.__name__ = 'Tensor';
  2537. tensor.name = tensor.data.key;
  2538. tensor.size = tensor.dims ? tensor.dims.map((dim) => parseInt(dim, 10)) : null;
  2539. tensor.storage = this.execution.invoke('torch.' + type + 'Storage', [ tensor.size ]);
  2540. tensor.storage.data = entries.get(key);
  2541. }
  2542. while (queue.length > 0) {
  2543. const module = queue.shift();
  2544. if (!module.__module__ && !module.__name__) {
  2545. module.__module__ = 'torch.nn.modules.module';
  2546. module.__name__ = 'Module';
  2547. }
  2548. if (module.name) {
  2549. module.__id__ = module.name;
  2550. }
  2551. if (module.submodules) {
  2552. for (const submodule of module.submodules) {
  2553. module[submodule.name] = submodule;
  2554. submodule.__parent__ = module;
  2555. queue.push(submodule);
  2556. }
  2557. delete module.submodules;
  2558. }
  2559. let parameters = [];
  2560. if (module.parameters) {
  2561. parameters = parameters.concat(module.parameters);
  2562. delete module.parameters;
  2563. }
  2564. if (module.arguments) {
  2565. parameters = parameters.concat(module.arguments);
  2566. delete module.arguments;
  2567. }
  2568. for (const parameter of parameters) {
  2569. const tensor = this._constants[parameter.tensorId];
  2570. module[parameter.name] = tensor;
  2571. if (!parameter.__module__ || !parameter.__name__) {
  2572. parameter.__module__ = 'torch';
  2573. parameter.__name__ = 'Tensor';
  2574. }
  2575. }
  2576. }
  2577. }
  2578. }
  2579. if (this.format.startsWith('PyTorch ')) {
  2580. const data = this._data;
  2581. this._data = pytorch.Utility.findRootModule(data);
  2582. if (!this._data) {
  2583. this._state = pytorch.Utility._findStateDict(data);
  2584. }
  2585. if (!this._data && !this._state && data !== 'None') {
  2586. throw new pytorch.Error('File does not contain root module or state dictionary.');
  2587. }
  2588. }
  2589. }
  2590. }
  2591. _unpickle(data, storage_map) {
  2592. const deserialized_objects = new Map();
  2593. const persistent_load = (saved_id) => {
  2594. const typename = saved_id.shift();
  2595. if (typename !== 'storage') {
  2596. throw new pytorch.Error("Unknown persistent load type '" + typename + "'.");
  2597. }
  2598. const data_type = saved_id.shift();
  2599. const root_key = saved_id.shift();
  2600. saved_id.shift(); // location
  2601. const size = saved_id.shift();
  2602. let storage = null;
  2603. if (deserialized_objects.has(root_key)) {
  2604. storage = deserialized_objects.get(root_key);
  2605. }
  2606. else {
  2607. storage = this.execution.invoke(data_type, [ size ]);
  2608. storage.data = storage_map.get(root_key);
  2609. deserialized_objects.set(root_key, storage);
  2610. }
  2611. const view_metadata = saved_id.shift();
  2612. if (view_metadata) {
  2613. const view_key = view_metadata.shift();
  2614. view_metadata.shift(); // view_offset
  2615. view_metadata.shift(); // view_size
  2616. let view = null;
  2617. if (deserialized_objects.has(view_key)) {
  2618. view = deserialized_objects.get(root_key);
  2619. }
  2620. else {
  2621. view = null; // storage.slice(view_offset, view_offset + view_size);
  2622. deserialized_objects.set(view_key, view);
  2623. }
  2624. return view;
  2625. }
  2626. return storage;
  2627. };
  2628. return new this._pickle.Unpickler(data).load((name, args) => this.execution.invoke(name, args), persistent_load);
  2629. }
  2630. _storage(dirname) {
  2631. const map = new Map();
  2632. const prefix = this._prefix + dirname + '/';
  2633. for (const entry of this._entries) {
  2634. if (entry.name.startsWith(prefix)) {
  2635. const key = entry.name.substring(prefix.length);
  2636. map.set(key, entry.data);
  2637. }
  2638. }
  2639. return map;
  2640. }
  2641. _type(name) {
  2642. if (!this._types.has(name)) {
  2643. const parts = name.split('.');
  2644. const className = parts.pop();
  2645. const file = 'code/' + parts.join('/') + '.py';
  2646. const program = this.execution.parse(file);
  2647. if (program) {
  2648. for (const statement of program.body) {
  2649. if (statement.type === 'class' && statement.name == className) {
  2650. this._types.set(name, statement);
  2651. break;
  2652. }
  2653. }
  2654. }
  2655. }
  2656. return this._types.get(name);
  2657. }
  2658. trace() {
  2659. this._inputs = [];
  2660. this._outputs = [];
  2661. this.execution.reset();
  2662. if (this._torchscriptArena) {
  2663. const program = this.execution.parse(this._torchscriptArena);
  2664. for (const statement of program.body) {
  2665. if (statement.type == 'def') {
  2666. const self = this;
  2667. const globals = this.execution.context;
  2668. const func = {
  2669. __class__: this.execution.context.scope.builtins.function,
  2670. __name__: statement.name,
  2671. __code__: statement,
  2672. __call__: function(args) {
  2673. return self.execution.apply(this.__code__, args, globals);
  2674. }
  2675. };
  2676. this.data[statement.name] = func;
  2677. }
  2678. }
  2679. }
  2680. if (this.data.forward) {
  2681. const args = [ this.data ]; // self
  2682. if (this.data.forward.__code__ && this.data.forward.__code__.parameters) {
  2683. for (const parameter of this.data.forward.__code__.parameters) {
  2684. if (parameter.name !== 'self') {
  2685. const type = parameter.parameterType;
  2686. if (type.type === 'type' && type.name.type) {
  2687. if (type.name.value === 'Tensor') {
  2688. this._inputs.push(parameter.name);
  2689. args.push({ __module__: 'torch', __name__: 'Tensor', __variable__: parameter.name, __origin__: 'trace-input-tensor' });
  2690. }
  2691. if (type.name.value === 'Tuple' && type.arguments.every((item) => item.type === 'type' && item.name.type === 'id' && item.name.value === 'Tensor')) {
  2692. this._inputs.push(parameter.name);
  2693. args.push(type.arguments.map(() => { return { __module__: 'torch', __name__: 'Tensor', __variable__: parameter.name, __origin__: 'trace-input-tuple' }; }));
  2694. }
  2695. if (type.name.value === 'List' && type.arguments.every((item) => item.type === 'type' && item.name.type === 'id' && item.name.value === 'Tensor')) {
  2696. this._inputs.push(parameter.name);
  2697. args.push([ { __module__: 'torch', __name__: 'Tensor', __variable__: parameter.name, size: [ NaN, NaN ], __origin__: 'trace-input-list' } ]);
  2698. }
  2699. }
  2700. }
  2701. }
  2702. }
  2703. const result = this.data.forward.__call__(args);
  2704. const outputs = !Array.isArray(result) ? [ result ] : result;
  2705. for (const output of outputs) {
  2706. if (pytorch.Utility.isTensor(output)) {
  2707. this._outputs.push(output.__variable__);
  2708. }
  2709. }
  2710. this._nodes = this.execution.nodes;
  2711. return true;
  2712. }
  2713. throw new pytorch.Error("Module 'forward' not implemented.");
  2714. }
  2715. get inputs() {
  2716. return this._inputs;
  2717. }
  2718. get outputs() {
  2719. return this._outputs;
  2720. }
  2721. get nodes() {
  2722. return this._nodes;
  2723. }
  2724. };
  2725. pytorch.Container.Zip.Execution = class extends pytorch.Execution {
  2726. constructor(python, sources, exceptionCallback, metadata) {
  2727. super(python, sources, exceptionCallback);
  2728. this._metadata = metadata;
  2729. this.reset();
  2730. }
  2731. reset() {
  2732. this._nodes = [];
  2733. this._variableIndex = 0;
  2734. }
    // Nodes recorded since the last reset(); the array itself is returned, not a copy.
    get nodes() {
        return this._nodes;
    }
  2738. call(target, name, args, context) {
  2739. let callTarget = pytorch.Utility.target(target);
  2740. let outputTypes = null;
  2741. if (callTarget && callTarget + '.' + name === 'ops.prim.NumToTensor' &&
  2742. args.length === 1 && args[0].type === 'call' && args[0].target.member.type == 'id') {
  2743. const innerCall = args[0];
  2744. callTarget = pytorch.Utility.target(innerCall.target.target);
  2745. args = innerCall.arguments;
  2746. name = innerCall.target.member.value;
  2747. outputTypes = [ 'int64' ];
  2748. }
  2749. if (callTarget) {
  2750. const type = callTarget + '.' + name;
  2751. // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
  2752. let schemas = this._metadata.type(type);
  2753. if (schemas) {
  2754. if (!Array.isArray(schemas)) {
  2755. schemas = [ schemas ];
  2756. }
  2757. const evalArgs = args.map((argument) => argument.type === '=' && argument.target && argument.target.type === 'id' ? this.expression(argument.expression, context) : this.expression(argument, context));
  2758. for (const schema of schemas) {
  2759. const copyArgs = Array.prototype.slice.call(args);
  2760. const copyEvalArgs = Array.prototype.slice.call(evalArgs);
  2761. const node = {
  2762. type: schema.name,
  2763. inputs: [],
  2764. attributes: [],
  2765. outputs: []
  2766. };
  2767. const referencedParameters = [];
  2768. let next = false;
  2769. const parameters = Array.prototype.slice.call(schema.inputs || []).concat(Array.prototype.slice.call(schema.attributes || []));
  2770. while (parameters.length > 0 && copyEvalArgs.length > 0) {
  2771. if (copyArgs.every((arg) => arg.type === '=' && arg.target && arg.target.type === 'id') &&
  2772. parameters.every((parameter) => parameter.type !== 'tensor' && parameter.type !== 'tensor[]')) {
  2773. const map = new Map();
  2774. for (const parameter of parameters) {
  2775. map.set(parameter.name, parameter);
  2776. }
  2777. while (copyArgs.length > 0) {
  2778. const arg = copyArgs.shift();
  2779. const value = copyEvalArgs.shift();
  2780. const parameter = map.get(arg.target.value);
  2781. if (!parameter) {
  2782. next = true;
  2783. break;
  2784. }
  2785. if (!pytorch.Utility.isType(value, parameter.type)) {
  2786. if (parameter.optional) {
  2787. continue;
  2788. }
  2789. next = true;
  2790. break;
  2791. }
  2792. node.attributes.push({ name: parameter.name, value: value });
  2793. }
  2794. continue;
  2795. }
  2796. if (next) {
  2797. break;
  2798. }
  2799. const parameter = parameters.shift();
  2800. switch (parameter.type) {
  2801. case 'tensor': {
  2802. let argument = copyEvalArgs[0];
  2803. if (Array.isArray(argument) || (!pytorch.Utility.isTensor(argument) && argument !== null && argument !== undefined)) {
  2804. if (parameter.optional) {
  2805. if (argument === undefined) {
  2806. copyArgs.shift();
  2807. copyEvalArgs.shift();
  2808. }
  2809. continue;
  2810. }
  2811. next = true;
  2812. break;
  2813. }
  2814. copyArgs.shift();
  2815. copyEvalArgs.shift();
  2816. if (argument === null || argument === undefined) {
  2817. argument = {};
  2818. }
  2819. if (!argument.__variable__) {
  2820. argument.__variable__ = this._variable();
  2821. }
  2822. const inputs = [];
  2823. inputs.push({ id: argument.__variable__ });
  2824. referencedParameters.push(argument);
  2825. node.inputs.push(inputs);
  2826. break;
  2827. }
  2828. case 'tensor[]': {
  2829. const argument = copyEvalArgs[0];
  2830. if (!Array.isArray(argument) || !argument.every((item) => pytorch.Utility.isTensor(item) || item === null)) {
  2831. if (parameter.optional) {
  2832. continue;
  2833. }
  2834. next = true;
  2835. break;
  2836. }
  2837. copyArgs.shift();
  2838. copyEvalArgs.shift();
  2839. const inputs = [];
  2840. for (let item of argument) {
  2841. if (item === null) {
  2842. item = {};
  2843. }
  2844. if (!item.__variable__) {
  2845. item.__variable__ = this._variable();
  2846. }
  2847. inputs.push({ id: item.__variable__ });
  2848. referencedParameters.push(item);
  2849. }
  2850. node.inputs.push(inputs);
  2851. break;
  2852. }
  2853. default: {
  2854. const arg = copyArgs[0];
  2855. const value = copyEvalArgs[0];
  2856. if (!pytorch.Utility.isType(value, parameter.type)) {
  2857. if (parameter.optional) {
  2858. continue;
  2859. }
  2860. next = true;
  2861. break;
  2862. }
  2863. if (arg.type !== '=') {
  2864. copyArgs.shift();
  2865. copyEvalArgs.shift();
  2866. node.attributes.push({ name: parameter.name, value: value });
  2867. }
  2868. else {
  2869. throw new pytorch.Error('Expected named argument.');
  2870. }
  2871. break;
  2872. }
  2873. }
  2874. if (next) {
  2875. break;
  2876. }
  2877. }
  2878. if (next) {
  2879. continue;
  2880. }
  2881. const result = [];
  2882. for (const paramter of schema.outputs) {
  2883. switch (paramter.type) {
  2884. case 'tensor': {
  2885. const parameter = { __module__: 'torch', __name__: 'Tensor', __origin__: 'invoke-output-' + type };
  2886. switch (type) {
  2887. case 'torch.cat':
  2888. case 'torch.conv2d':
  2889. case 'torch.dropout':
  2890. case 'torch.flatten':
  2891. case 'torch.max_pool2d':
  2892. case 'torch.quantize_per_tensor':
  2893. case 'torch.relu_':
  2894. case 'torch.hardtanh_':
  2895. case 'torch.slice': {
  2896. parameter.size = [ NaN, NaN, NaN, NaN ];
  2897. break;
  2898. }
  2899. case 'torch.conv3d': {
  2900. parameter.size = [ NaN, NaN, NaN, NaN, NaN ];
  2901. break;
  2902. }
  2903. case 'torch.embedding': {
  2904. parameter.size = [ NaN, NaN, NaN ];
  2905. break;
  2906. }
  2907. case 'torch.ones':
  2908. case 'torch.zeros':
  2909. case 'torch.zeros_like': {
  2910. parameter.size = this.expression(args[0], context);
  2911. break;
  2912. }
  2913. }
  2914. parameter.__variable__ = this._variable();
  2915. result.push(parameter);
  2916. node.outputs.push([ { id: parameter.__variable__ } ]);
  2917. break;
  2918. }
  2919. case 'tensor[]': {
  2920. let count = 1;
  2921. switch (type) {
  2922. case 'torch.chunk':
  2923. count = node.attributes.filter((attribute) => attribute.name == 'chunks')[0].value;
  2924. break;
  2925. }
  2926. const tensors = [];
  2927. const outputs = [];
  2928. for (let i = 0; i < count; i ++) {
  2929. const tensor = { __module__: 'torch', __name__: 'Tensor', __origin__: 'invoke-output-' + type };
  2930. tensor.__variable__ = this._variable();
  2931. tensors.push(tensor);
  2932. outputs.push({ id: tensor.__variable__ });
  2933. }
  2934. result.push(tensors);
  2935. node.outputs.push(outputs);
  2936. break;
  2937. }
  2938. default: {
  2939. if (!outputTypes || schema.outputs.length !== 1 || schema.outputs[0].type !== outputTypes[0]) {
  2940. next = true;
  2941. break;
  2942. }
  2943. const tensor = { __module__: 'torch', __name__: 'Tensor', __origin__: 'invoke-output-' + type };
  2944. tensor.__variable__ = this._variable();
  2945. result.push(tensor);
  2946. node.outputs.push([ { id: tensor.__variable__ } ]);
  2947. break;
  2948. }
  2949. }
  2950. }
  2951. if (next) {
  2952. continue;
  2953. }
  2954. for (const parameter of referencedParameters) {
  2955. parameter.__count__ = (parameter.__count__ || 0) + 1;
  2956. }
  2957. this._nodes.push(node);
  2958. if (result.length > 1) {
  2959. return result;
  2960. }
  2961. return result[0];
  2962. }
  2963. }
  2964. }
  2965. return super.call(target, name, args, context);
  2966. }
  2967. _variable() {
  2968. this._variableIndex++;
  2969. return this._variableIndex.toString();
  2970. }
  2971. };
  2972. pytorch.ScalarType = {
  2973. uint8: 0, int8: 1, int16: 2, int32: 3, int64: 4,
  2974. float16: 5, float32: 6, float64: 7,
  2975. complex32: 8, complex64: 9, complex128: 10,
  2976. boolean: 11,
  2977. qint8: 12, quint8: 13, qint32: 14, bfloat16: 15
  2978. };
  2979. pytorch.MemoryFormat = {
  2980. Contiguous: 0, Preserve: 1, ChannelsLast: 2, ChannelsLast3d: 3
  2981. };
  2982. pytorch.Layout = {
  2983. Strided: 0, Sparse: 1, Mkldnn: 2
  2984. };
  2985. pytorch.Utility = class {
  2986. static target(expression) {
  2987. if (expression.type == 'id') {
  2988. return expression.value;
  2989. }
  2990. if (expression.type == '.') {
  2991. return pytorch.Utility.target(expression.target) + '.' + pytorch.Utility.target(expression.member);
  2992. }
  2993. return null;
  2994. }
  2995. static isTensor(obj) {
  2996. return obj && (obj.__module__ === 'torch' || obj.__module__ === 'torch.cuda') && obj.__name__ && obj.__name__.endsWith('Tensor');
  2997. }
  2998. static isType(obj, type) {
  2999. switch (type) {
  3000. case 'tensor':
  3001. return !Array.isArray(obj) && (pytorch.Utility.isTensor(obj) || obj === null);
  3002. case 'tensor[]':
  3003. return Array.isArray(obj) && obj.length > 0 && obj.every((tensor) => pytorch.Utility.isTensor(tensor) || tensor === null);
  3004. case 'boolean':
  3005. return obj === true || obj === false;
  3006. case 'int64':
  3007. return Number.isInteger(obj) || isNaN(obj);
  3008. case 'int64[]':
  3009. return Array.isArray(obj) && obj.every((item) => Number.isInteger(item) || Number.isNaN(item) || item === undefined);
  3010. case 'float32':
  3011. case 'float64':
  3012. return obj !== null && obj !== Object(obj);
  3013. case 'Layout':
  3014. case 'ScalarType':
  3015. case 'MemoryFormat':
  3016. return Number.isInteger(obj);
  3017. case 'Device':
  3018. return obj === null || obj === Object(obj);
  3019. case 'scalar':
  3020. return obj !== null || obj !== Object(obj);
  3021. }
  3022. return true;
  3023. }
  3024. static findRootModule(root) {
  3025. const candidates = [ root, root.model, root.net ];
  3026. for (const obj of candidates) {
  3027. if (obj && obj._modules) {
  3028. return obj;
  3029. }
  3030. }
  3031. return null;
  3032. }
  3033. static _findStateDict(root) {
  3034. if (!root) {
  3035. return null;
  3036. }
  3037. if (root.encoder && Array.isArray(root.encoder) &&
  3038. root.decoder && Array.isArray(root.decoder) && !root.state_dict) {
  3039. root = root.encoder.concat(root.decoder);
  3040. }
  3041. if (root instanceof Map) {
  3042. const obj = {};
  3043. for (const pair of root) {
  3044. const key = pair[0];
  3045. const value = pair[1];
  3046. obj[key] = value;
  3047. }
  3048. root = obj;
  3049. }
  3050. const candidates = [
  3051. root.state_dict, root.state,
  3052. root.model_state, root.model, root.model_state_dict, root.net_dict,
  3053. root.params, root.generator, root.discriminator, root.g_state,
  3054. root.network, root.net, root.netG, root.net_states,
  3055. root.state_dict_stylepredictor, root.state_dict_ghiasi,
  3056. root
  3057. ];
  3058. for (const dict of candidates) {
  3059. let state_dict = null;
  3060. state_dict = state_dict || pytorch.Utility._convertStateDictList(dict);
  3061. state_dict = state_dict || pytorch.Utility._convertStateDictMap(dict);
  3062. state_dict = state_dict || pytorch.Utility._convertStateDictGroupMap(dict);
  3063. if (state_dict) {
  3064. return state_dict;
  3065. }
  3066. }
  3067. return null;
  3068. }
  3069. static _convertStateDictList(list) {
  3070. if (list && Array.isArray(list) && list.every((obj) => obj.__module__ && obj.__name__ && Object.keys(obj).filter((key) => pytorch.Utility.isTensor(obj[key]).length > 0))) {
  3071. const layers = [];
  3072. for (const obj of list) {
  3073. const layer = { type: obj.__module__ + '.' + obj.__name__, states: [], attributes: [] };
  3074. for (const key of Object.keys(obj)) {
  3075. const value = obj[key];
  3076. if (pytorch.Utility.isTensor(value)) {
  3077. layer.states.push({ name: key, arguments: [ { id: '', value: value } ] });
  3078. }
  3079. else {
  3080. layer.attributes.push({ name: key, value: value });
  3081. }
  3082. }
  3083. layers.push(layer);
  3084. }
  3085. return layers;
  3086. }
  3087. if (list && !Array.isArray(list) && !(list instanceof Map)) {
  3088. list = new Map(Object.keys(list).map((key) => [ key, list[key] ]));
  3089. }
  3090. if (list && list instanceof Map) {
  3091. for (const item of list) {
  3092. const key = item[0];
  3093. const value = item[1];
  3094. if (!key || !value) {
  3095. return null;
  3096. }
  3097. if (pytorch.Utility.isTensor(value)) {
  3098. continue;
  3099. }
  3100. if (key.endsWith('._packed_params.dtype')) {
  3101. continue;
  3102. }
  3103. if (key.endsWith('._packed_params._packed_params') && Array.isArray(value) && value.every((item) => pytorch.Utility.isTensor(item))) {
  3104. continue;
  3105. }
  3106. return null;
  3107. }
  3108. const layers = new Map();
  3109. for (const item of list) {
  3110. const key = item[0];
  3111. const value = item[1];
  3112. if (value !== null) {
  3113. let layerName = '';
  3114. let parameter = '';
  3115. if (key.endsWith('_packed_params.dtype')) {
  3116. parameter = '_packed_params.dtype';
  3117. layerName = key.substring(0, key.length - parameter.length - 1);
  3118. }
  3119. else if (key.endsWith('_packed_params._packed_params') && Array.isArray(value)) {
  3120. parameter = '_packed_params._packed_params';
  3121. layerName = key.substring(0, key.length - parameter.length - 1);
  3122. }
  3123. else {
  3124. let split = key.split('.');
  3125. if (split.length < 2) {
  3126. split = [ '', split[0] ];
  3127. }
  3128. parameter = split.pop();
  3129. layerName = split.join('.');
  3130. }
  3131. if (!layers.has(layerName)) {
  3132. layers.set(layerName, { name: layerName, states: [], attributes: [] });
  3133. }
  3134. const layer = layers.get(layerName);
  3135. switch (parameter) {
  3136. case '_packed_params.dtype':
  3137. layer.attributes.push({ name: parameter, value: value });
  3138. break;
  3139. case '_packed_params._packed_params':
  3140. layer.states.push({ name: parameter, arguments: value.map((item) => { return { id: '', value: item }; }) });
  3141. break;
  3142. default:
  3143. layer.states.push({ name: parameter, arguments: [ { id: key, value: value } ] });
  3144. if (layer.name == '' && layer.states.length > 4) {
  3145. return null;
  3146. }
  3147. break;
  3148. }
  3149. }
  3150. }
  3151. return layers.values();
  3152. }
  3153. return null;
  3154. }
  3155. static _convertStateDictMap(obj) {
  3156. if (!obj || Array.isArray(obj)) {
  3157. return null;
  3158. }
  3159. const state_dict = [];
  3160. const state_map = {};
  3161. for (const key in obj) {
  3162. const split = key.split('.');
  3163. if (split.length < 1) {
  3164. return null;
  3165. }
  3166. const state = {};
  3167. state.id = key;
  3168. state.name = split.pop();
  3169. state.value = obj[key];
  3170. if (state.value && state.value.__module__ === 'torch.nn.parameter' && state.value.__name__ === 'Parameter') {
  3171. if (pytorch.Utility.isTensor(state.value.data)) {
  3172. state.value = state.value.data;
  3173. }
  3174. }
  3175. if (!pytorch.Utility.isTensor(state.value)) {
  3176. return null;
  3177. }
  3178. const state_group_name = split.join('.');
  3179. let state_group = state_map[state_group_name];
  3180. if (!state_group) {
  3181. state_group = {};
  3182. state_group.name = state_group_name;
  3183. state_group.states = [];
  3184. state_map[state_group_name] = state_group;
  3185. state_dict.push(state_group);
  3186. }
  3187. state_group.states.push({ name: state.name, arguments: [ state ] });
  3188. }
  3189. return state_dict;
  3190. }
  3191. static _convertStateDictGroupMap(obj) {
  3192. if (!obj || Array.isArray(obj)) {
  3193. return null;
  3194. }
  3195. const state_dict = [];
  3196. const state_map = {};
  3197. for (const state_group_name in obj) {
  3198. let state_group = state_map[state_group_name];
  3199. if (!state_group) {
  3200. state_group = {};
  3201. state_group.name = state_group_name;
  3202. state_group.states = [];
  3203. state_group.attributes = [];
  3204. state_map[state_group_name] = state_group;
  3205. state_dict.push(state_group);
  3206. }
  3207. const item = obj[state_group_name];
  3208. if (!item) {
  3209. return null;
  3210. }
  3211. if (item instanceof Map) {
  3212. for (const pair of item) {
  3213. const key = pair[0];
  3214. const value = pair[1];
  3215. if (!key) {
  3216. return null;
  3217. }
  3218. if (value && !pytorch.Utility.isTensor(value)) {
  3219. return null;
  3220. }
  3221. const argument = { id: state_group_name + '.' + key, value: value };
  3222. state_group.states.push({ name: key, arguments: [ argument ] });
  3223. }
  3224. }
  3225. else if (item instanceof Uint8Array) {
  3226. return null;
  3227. }
  3228. else if (Object(item) === item) {
  3229. let hasTensors = false;
  3230. for (const key in item) {
  3231. const value = item[key];
  3232. if (pytorch.Utility.isTensor(value)) {
  3233. const argument = { id: state_group_name + '.' + key, value: value };
  3234. state_group.states.push({ name: key, arguments: [ argument ] });
  3235. hasTensors = true;
  3236. }
  3237. else if (value !== Object(value)) {
  3238. state_group.attributes.push({ name: key, value: value });
  3239. }
  3240. else if (value && value.data && value.__module__ === 'torch.nn.parameter' && value.__name__ === 'Parameter') {
  3241. const argument = { id: state_group_name + '.' + key, value: value.data };
  3242. state_group.states.push({ name: key, arguments: [ argument ] });
  3243. hasTensors = true;
  3244. }
  3245. else {
  3246. return null;
  3247. }
  3248. }
  3249. if (!hasTensors) {
  3250. return null;
  3251. }
  3252. }
  3253. else {
  3254. return null;
  3255. }
  3256. }
  3257. return state_dict;
  3258. }
  3259. static readInt32(buffer) {
  3260. const view = new DataView(buffer.buffer, buffer.byteOffset, buffer.byteLength);
  3261. return view.getInt32(0, true);
  3262. }
  3263. static readInt64(buffer) {
  3264. const view = new DataView(buffer.buffer, buffer.byteOffset, buffer.byteLength);
  3265. return view.getInt64(0, true).toNumber();
  3266. }
  3267. };
  3268. if (typeof module !== 'undefined' && typeof module.exports === 'object') {
  3269. module.exports.ModelFactory = pytorch.ModelFactory;
  3270. }