modeling_utils.py

# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TF general model utils."""
import functools
import logging
import os

import h5py
import numpy as np
import tensorflow as tf
from tensorflow.python.keras.saving import hdf5_format

from configuration_utils import PretrainedConfig, BertConfig
from file_utils import DUMMY_INPUTS, TF2_WEIGHTS_NAME, WEIGHTS_NAME, cached_path, hf_bucket_url, is_remote_url
from file_utils import MULTIPLE_CHOICE_DUMMY_INPUTS, add_start_docstrings, add_start_docstrings_to_callable
from tokenization_utils import BatchEncoding
from utils import log


class TFModelUtilsMixin:
    """
    A few utilities for `tf.keras.Model`s, to be used as a mixin.
    """

    def num_parameters(self, only_trainable: bool = False) -> int:
        """
        Get number of (optionally, trainable) parameters in the model.
        """
        if only_trainable:
            return int(sum(np.prod(w.shape.as_list()) for w in self.trainable_variables))
        else:
            return self.count_params()


def keras_serializable(cls):
    """
    Decorate a Keras Layer class to support Keras serialization.

    This is done by:

    1. adding a `transformers_config` dict to the Keras config dictionary in `get_config` (called by Keras at
       serialization time)
    2. wrapping `__init__` to accept that `transformers_config` dict (passed by Keras at deserialization time) and
       convert it to a config object for the actual layer initializer
    3. registering the class as a custom object in Keras (if the Tensorflow version supports this), so that it does
       not need to be supplied in `custom_objects` in the call to `tf.keras.models.load_model`

    :param cls: a tf.keras.layers.Layer subclass that accepts a `config` argument to its initializer (typically a
        `TF*MainLayer` class in this project)
    :return: the same class object, with modifications for Keras deserialization.
    """
    initializer = cls.__init__

    config_class = getattr(cls, "config_class", None)
    if config_class is None:
        raise AttributeError("Must set `config_class` to use @keras_serializable")

    @functools.wraps(initializer)
    def wrapped_init(self, *args, **kwargs):
        transformers_config = kwargs.pop("transformers_config", None)
        config = args[0] if args and isinstance(args[0], PretrainedConfig) else kwargs.get("config", None)
        if config is not None and transformers_config is not None:
            raise ValueError("Must pass either `config` or `transformers_config`, not both")
        elif config is not None:
            # normal layer construction, call with unchanged args (config is already in there)
            initializer(self, *args, **kwargs)
        elif transformers_config is not None:
            # Keras deserialization, convert dict to config
            config = config_class.from_dict(transformers_config)
            initializer(self, config, *args, **kwargs)
        else:
            raise ValueError("Must pass either `config` (PretrainedConfig) or `transformers_config` (dict)")
        self._transformers_config = config

    cls.__init__ = wrapped_init

    if not hasattr(cls, "get_config"):
        raise TypeError("Only use @keras_serializable on tf.keras.layers.Layer subclasses")
    if hasattr(cls.get_config, "_is_default"):

        def get_config(self):
            cfg = super(cls, self).get_config()
            cfg["transformers_config"] = self._transformers_config.to_dict()
            return cfg

        cls.get_config = get_config

    cls._keras_serializable = True
    if hasattr(tf.keras.utils, "register_keras_serializable"):
        cls = tf.keras.utils.register_keras_serializable()(cls)
    return cls
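
# Usage sketch (editor's illustration, not part of the library): a hypothetical
# `TFMyMainLayer` with a `config_class` attribute becomes Keras-serializable once
# decorated, so models containing it reload without passing `custom_objects`.
#
#     @keras_serializable
#     class TFMyMainLayer(tf.keras.layers.Layer):
#         config_class = BertConfig
#
#         def __init__(self, config, **kwargs):
#             super().__init__(**kwargs)
#             self.hidden_size = config.hidden_size
#
#     # tf.keras.models.load_model(path) then rebuilds the layer from the
#     # `transformers_config` dict stored in its Keras config.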


class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
    r""" Base class for all TF models.

        :class:`~transformers.TFPreTrainedModel` takes care of storing the configuration of the models and handles methods for loading/downloading/saving models
        as well as a few methods common to all models to (i) resize the input embeddings and (ii) prune heads in the self-attention heads.

        Class attributes (overridden by derived classes):
            - ``config_class``: a class derived from :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
            - ``pretrained_model_archive_map``: a python ``dict`` with `short-cut-names` (string) as keys and `url` (string) of associated pretrained weights as values.
            - ``load_tf_weights``: a python ``method`` for loading a TensorFlow checkpoint in a PyTorch model, taking as arguments:

                - ``model``: an instance of the relevant subclass of :class:`~transformers.PreTrainedModel`,
                - ``config``: an instance of the relevant subclass of :class:`~transformers.PretrainedConfig`,
                - ``path``: a path (string) to the TensorFlow checkpoint.

            - ``base_model_prefix``: a string indicating the attribute associated to the base model in derived classes of the same architecture adding modules on top of the base model.
    """
    config_class = None
    pretrained_model_archive_map = {}
    base_model_prefix = ""

    @property
    def dummy_inputs(self):
        """ Dummy inputs to build the network.

        Returns:
            tf.Tensor with dummy inputs
        """
        return {"input_ids": tf.constant(DUMMY_INPUTS)}

    def __init__(self, config, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)
        if not isinstance(config, PretrainedConfig):
            raise ValueError(
                "Parameter config in `{}(config)` should be an instance of class `PretrainedConfig`. "
                "To create a model from a pretrained model use "
                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                    self.__class__.__name__, self.__class__.__name__
                )
            )
        # Save config in model
        self.config = config

    def get_input_embeddings(self):
        """
        Returns the model's input embeddings.

        Returns:
            :obj:`tf.keras.layers.Layer`:
                A Keras layer mapping vocabulary to hidden states.
        """
        base_model = getattr(self, self.base_model_prefix, self)
        if base_model is not self:
            return base_model.get_input_embeddings()
        else:
            raise NotImplementedError

    def get_output_embeddings(self):
        """
        Returns the model's output embeddings.

        Returns:
            :obj:`tf.keras.layers.Layer`:
                A Keras layer mapping hidden states to vocabulary.
        """
        return None  # Overwrite for models with output embeddings

    def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None):
        """ Build a resized Embedding Variable from a provided token Embedding Module.
            Increasing the size will add newly initialized vectors at the end
            Reducing the size will remove vectors from the end

        Args:
            new_num_tokens: (`optional`) int
                New number of tokens in the embedding matrix.
                Increasing the size will add newly initialized vectors at the end
                Reducing the size will remove vectors from the end
                If not provided or None: return the provided token Embedding Module.

        Return: ``tf.Variable``
            Pointer to the resized Embedding Module or the old Embedding Module if new_num_tokens is None
        """
        # if new_num_tokens is None:
        #     return old_embeddings
        # old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
        # if old_num_tokens == new_num_tokens:
        #     return old_embeddings
        # # Build new embeddings
        # new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
        # new_embeddings.to(old_embeddings.weight.device)
        # # initialize all new embeddings (in particular added tokens)
        # self._init_weights(new_embeddings)
        # # Copy token embeddings from the previous weights
        # num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
        # new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[:num_tokens_to_copy, :]
        # return new_embeddings
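
    # Editor's note: the commented block above is the PyTorch logic this method was ported
    # from. A TF 2.x sketch of the same idea (illustrative only, not part of the original
    # API; `initializer_range` is an assumed stddev for newly added rows) could look like:
    #
    #     def _resize_embedding_variable(old_embeddings, new_num_tokens, initializer_range=0.02):
    #         old_num_tokens, embedding_dim = old_embeddings.shape
    #         new_rows = tf.random.normal((new_num_tokens, embedding_dim), stddev=initializer_range)
    #         num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
    #         # keep the old vectors, then append (or drop) rows at the end
    #         resized = tf.concat([old_embeddings[:num_tokens_to_copy], new_rows[num_tokens_to_copy:]], axis=0)
    #         return tf.Variable(resized)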

    def resize_token_embeddings(self, new_num_tokens=None):
        """ Resize input token embeddings matrix of the model if new_num_tokens != config.vocab_size.
        Take care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.

        Arguments:
            new_num_tokens: (`optional`) int:
                New number of tokens in the embedding matrix. Increasing the size will add newly initialized vectors at the end. Reducing the size will remove vectors from the end.
                If not provided or None: does nothing and just returns a pointer to the input tokens ``tf.Variable`` Module of the model.

        Return: ``tf.Variable``
            Pointer to the input tokens Embeddings Module of the model
        """
        raise NotImplementedError

    def prune_heads(self, heads_to_prune):
        """ Prunes heads of the base model.

        Arguments:
            heads_to_prune: dict with keys being selected layer indices (`int`) and associated values being the list of heads to prune in said layer (list of `int`).
        """
        raise NotImplementedError

    def save_pretrained(self, save_directory):
        """ Save a model and its configuration file to a directory, so that it
            can be re-loaded using the :func:`~transformers.PreTrainedModel.from_pretrained` class method.
        """
        if os.path.isfile(save_directory):
            log("Provided path ({}) should be a directory, not a file".format(save_directory))
            return
        os.makedirs(save_directory, exist_ok=True)

        # Save configuration file
        self.config.save_pretrained(save_directory)

        # If we save using the predefined names, we can load using `from_pretrained`
        output_model_file = os.path.join(save_directory, TF2_WEIGHTS_NAME)
        self.save_weights(output_model_file)
        with h5py.File(output_model_file, "r") as f:
            if "layer_names" not in f.attrs and "model_weights" in f:
                f = f["model_weights"]
            hdf5_layer_names = set(hdf5_format.load_attributes_from_hdf5_group(f, "layer_names"))
        log(f"Model weights saved in {output_model_file}: {hdf5_layer_names}")

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        r"""Instantiate a pretrained TF 2.0 model from a pre-trained model configuration.

        The warning ``Weights from XXX not initialized from pretrained model`` means that the weights of XXX do not come pre-trained with the rest of the model.
        It is up to you to train those weights with a downstream fine-tuning task.
        The warning ``Weights from XXX not used in YYY`` means that the layer XXX is not used by YYY, therefore those weights are discarded.

        Parameters:
            pretrained_model_name_or_path: either:
                - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
                - a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
                - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
                - a path or url to a `PyTorch state_dict save file` (e.g. `./pt_model/pytorch_model.bin`). In this case, ``from_pt`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the PyTorch checkpoint in a TensorFlow model using the provided conversion scripts and loading the TensorFlow model afterwards.
            model_args: (`optional`) Sequence of positional arguments:
                All remaining positional arguments will be passed to the underlying model's ``__init__`` method.
            config: (`optional`) one of:
                - an instance of a class derived from :class:`~transformers.PretrainedConfig`, or
                - a string valid as input to :func:`~transformers.PretrainedConfig.from_pretrained()`

                Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
                    - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
                    - the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by supplying the save directory.
                    - the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.
            from_pt: (`optional`) boolean, default False:
                Load the model weights from a PyTorch state_dict save file (see docstring of pretrained_model_name_or_path argument).
            cache_dir: (`optional`) string:
                Path to a directory in which a downloaded pre-trained model
                configuration should be cached if the standard cache should not be used.
            force_download: (`optional`) boolean, default False:
                Force to (re-)download the model weights and configuration files and override the cached versions if they exist.
            resume_download: (`optional`) boolean, default False:
                Do not delete an incompletely received file. Attempt to resume the download if such a file exists.
            proxies: (`optional`) dict, default None:
                A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
                The proxies are used on each request.
            output_loading_info: (`optional`) boolean:
                Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.
            kwargs: (`optional`) Remaining dictionary of keyword arguments:
                Can be used to update the configuration object (after it has been loaded) and initiate the model (e.g. ``output_attention=True``). Behaves differently depending on whether a `config` is provided or automatically loaded:

                - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the underlying model's ``__init__`` method (we assume all relevant updates to the configuration have already been done)
                - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's ``__init__`` function.

        Examples::

            # For example purposes. Not runnable.
            model = BertModel.from_pretrained('bert-base-uncased')    # Download model and configuration from S3 and cache.
            model = BertModel.from_pretrained('./test/saved_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
            model = BertModel.from_pretrained('bert-base-uncased', output_attention=True)  # Update configuration during loading
            assert model.config.output_attention == True
            # Loading from a PyTorch checkpoint file instead of a TF 2.0 model (slower)
            config = BertConfig.from_json_file('./pt_model/my_pt_model_config.json')
            model = BertModel.from_pretrained('./pt_model/pytorch_model.bin', from_pt=True, config=config)

        """
        config = kwargs.pop("config", None)
        cache_dir = kwargs.pop("cache_dir", None)
        from_pt = kwargs.pop("from_pt", False)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        output_loading_info = kwargs.pop("output_loading_info", False)

        # Load config if we don't provide a configuration
        if not isinstance(config, PretrainedConfig):
            config_path = config if config is not None else pretrained_model_name_or_path
            config, model_kwargs = cls.config_class.from_pretrained(
                config_path,
                *model_args,
                cache_dir=cache_dir,
                return_unused_kwargs=True,
                force_download=force_download,
                resume_download=resume_download,
                **kwargs,
            )
        else:
            model_kwargs = kwargs

        # Load model
        if pretrained_model_name_or_path is not None:
            if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
                archive_file = cls.pretrained_model_archive_map[pretrained_model_name_or_path]
            elif os.path.isdir(pretrained_model_name_or_path):
                if os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
                    # Load from a TF 2.0 checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
                elif from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
                    # Load from a PyTorch checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
                else:
                    raise EnvironmentError(
                        "Error no file named {} found in directory {} or `from_pt` set to False".format(
                            [WEIGHTS_NAME, TF2_WEIGHTS_NAME], pretrained_model_name_or_path
                        )
                    )
            elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
                archive_file = pretrained_model_name_or_path
            elif os.path.isfile(pretrained_model_name_or_path + ".index"):
                archive_file = pretrained_model_name_or_path + ".index"
            else:
                archive_file = hf_bucket_url(
                    pretrained_model_name_or_path, postfix=(WEIGHTS_NAME if from_pt else TF2_WEIGHTS_NAME)
                )

            # redirect to the cache, if necessary
            try:
                resolved_archive_file = cached_path(
                    archive_file,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    resume_download=resume_download,
                    proxies=proxies,
                )
            except EnvironmentError as e:
                if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
                    log("Couldn't reach server at '{}' to download pretrained weights.".format(archive_file))
                else:
                    log(
                        "Model name '{}' was not found in model name list ({}). "
                        "We assumed '{}' was a path or url but couldn't find any file "
                        "associated to this path or url.".format(
                            pretrained_model_name_or_path,
                            ", ".join(cls.pretrained_model_archive_map.keys()),
                            archive_file,
                        )
                    )
                raise e

            if resolved_archive_file == archive_file:
                log("loading weights file {}".format(archive_file))
            else:
                log("loading weights file {} from cache at {}".format(archive_file, resolved_archive_file))
        else:
            resolved_archive_file = None

        # Instantiate model.
        model = cls(config, *model_args, **model_kwargs)

        if from_pt:
            # Load from a PyTorch checkpoint
            raise NotImplementedError
            # return load_pytorch_checkpoint_in_tf2_model(model, resolved_archive_file, allow_missing_keys=True)

        model(model.dummy_inputs, training=False)  # build the network with dummy inputs

        assert os.path.isfile(resolved_archive_file), "Error retrieving file {}".format(resolved_archive_file)
        # 'by_name' allows us to do transfer learning by skipping/adding layers
        # see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
        try:
            model.load_weights(resolved_archive_file, by_name=True)
        except OSError:
            raise OSError(
                "Unable to load weights from h5 file. "
                "If you tried to load a TF 2.0 model from a PyTorch checkpoint, please set from_pt=True. "
            )

        model(model.dummy_inputs, training=False)  # Make sure restore ops are run

        # Check if the models are the same to output loading information
        with h5py.File(resolved_archive_file, "r") as f:
            if "layer_names" not in f.attrs and "model_weights" in f:
                f = f["model_weights"]
            hdf5_layer_names = set(hdf5_format.load_attributes_from_hdf5_group(f, "layer_names"))
        model_layer_names = set(layer.name for layer in model.layers)
        missing_keys = list(model_layer_names - hdf5_layer_names)
        unexpected_keys = list(hdf5_layer_names - model_layer_names)
        error_msgs = []

        if len(unexpected_keys) > 0:
            log(
                f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
                f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
            )
        else:
            log(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")

        if len(missing_keys) > 0:
            log(
                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
                f"and are newly initialized: {missing_keys}\n"
            )
        else:
            log(
                f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
                f"If your task is similar to the task the model of the checkpoint was trained on, "
                f"you can already use {model.__class__.__name__} for predictions without further training."
            )

        if len(error_msgs) > 0:
            raise RuntimeError(
                "Error(s) in loading weights for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs))
            )

        if output_loading_info:
            loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys, "error_msgs": error_msgs}
            return model, loading_info

        return model

    def prepare_inputs_for_generation(self, inputs, **kwargs):
        return {"inputs": inputs}
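
    # Editorial note: `_do_output_past` below is a heuristic. It returns True when the config
    # enables `output_past` (or defines a positive `mem_len`, as in Transfo-XL/XLNet-style
    # configs) and the forward pass returned more than one output, in which case `outputs[1]`
    # can be fed back as `past` to speed up decoding.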
    def _do_output_past(self, outputs):
        has_output_past = hasattr(self.config, "output_past") and self.config.output_past
        has_mem_len = hasattr(self.config, "mem_len") and self.config.mem_len

        if has_output_past and not has_mem_len and len(outputs) > 1:
            return True
        elif has_mem_len and self.config.mem_len > 0 and len(outputs) > 1:
            return True

        return False

    def generate(
        self,
        input_ids=None,
        max_length=None,
        min_length=None,
        do_sample=None,
        early_stopping=None,
        num_beams=None,
        temperature=None,
        top_k=None,
        top_p=None,
        repetition_penalty=None,
        bad_words_ids=None,
        bos_token_id=None,
        pad_token_id=None,
        eos_token_id=None,
        length_penalty=None,
        no_repeat_ngram_size=None,
        num_return_sequences=None,
        attention_mask=None,
        decoder_start_token_id=None,
    ):
  414. r""" Generates sequences for models with a LM head. The method currently supports greedy or penalized greedy decoding, sampling with top-k or nucleus sampling
  415. and beam-search.
  416. Adapted in part from `Facebook's XLM beam search code`_.
  417. .. _`Facebook's XLM beam search code`:
  418. https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529
  419. Parameters:
  420. input_ids: (`optional`) `tf.Tensor` of `dtype=tf.int32` of shape `(batch_size, sequence_length)`
  421. The sequence used as a prompt for the generation. If `None` the method initializes
  422. it as an empty `torch.LongTensor` of shape `(1,)`.
  423. max_length: (`optional`) int
  424. The max length of the sequence to be generated. Between 1 and infinity. Default to 20.
  425. min_length: (`optional`) int
  426. The min length of the sequence to be generated. Between 0 and infinity. Default to 0.
  427. do_sample: (`optional`) bool
  428. If set to `False` greedy decoding is used. Otherwise sampling is used. Defaults to `False` as defined in `configuration_utils.PretrainedConfig`.
  429. early_stopping: (`optional`) bool
  430. if set to `True` beam search is stopped when at least `num_beams` sentences finished per batch. Defaults to `False` as defined in `configuration_utils.PretrainedConfig`.
  431. num_beams: (`optional`) int
  432. Number of beams for beam search. Must be between 1 and infinity. 1 means no beam search. Default to 1.
  433. temperature: (`optional`) float
  434. The value used to module the next token probabilities. Must be strictely positive. Default to 1.0.
  435. top_k: (`optional`) int
  436. The number of highest probability vocabulary tokens to keep for top-k-filtering. Between 1 and infinity. Default to 50.
  437. top_p: (`optional`) float
  438. The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Must be between 0 and 1. Default to 1.
  439. repetition_penalty: (`optional`) float
  440. The parameter for repetition penalty. Between 1.0 and infinity. 1.0 means no penalty. Default to 1.0.
  441. bos_token_id: (`optional`) int
  442. Beginning of sentence token if no prompt is provided. Default to specicic model bos_token_id or None if it does not exist.
  443. pad_token_id: (`optional`) int
  444. Pad token. Defaults to pad_token_id as defined in the models config.
  445. eos_token_id: (`optional`) int
  446. EOS token. Defaults to eos_token_id as defined in the models config.
  447. length_penalty: (`optional`) float
  448. Exponential penalty to the length. Default to 1.
  449. no_repeat_ngram_size: (`optional`) int
  450. If set to int > 0, all ngrams of size `no_repeat_ngram_size` can only occur once.
  451. bad_words_ids: (`optional`) list of lists of int
  452. `bad_words_ids` contains tokens that are not allowed to be generated. In order to get the tokens of the words that should not appear in the generated text, use `tokenizer.encode(bad_word, add_prefix_space=True)`.
  453. num_return_sequences: (`optional`) int
  454. The number of independently computed returned sequences for each element in the batch. Default to 1.
  455. attention_mask (`optional`) obj: `tf.Tensor` with `dtype=tf.int32` of same shape as `input_ids`
  456. Mask to avoid performing attention on padding token indices.
  457. Mask values selected in ``[0, 1]``:
  458. ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
  459. Defaults to `None`.
  460. `What are attention masks? <../glossary.html#attention-mask>`__
  461. decoder_start_token_id=None: (`optional`) int
  462. If an encoder-decoder model starts decoding with a different token than BOS.
  463. Defaults to `None` and is changed to `BOS` later.
  464. Return:
  465. output: `tf.Tensor` of `dtype=tf.int32` shape `(batch_size * num_return_sequences, sequence_length)`
  466. sequence_length is either equal to max_length or shorter if all batches finished early due to the `eos_token_id`
  467. Examples::
  468. tokenizer = AutoTokenizer.from_pretrained('distilgpt2') # Initialize tokenizer
  469. model = TFAutoModelWithLMHead.from_pretrained('distilgpt2') # Download model and configuration from S3 and cache.
  470. outputs = model.generate(max_length=40) # do greedy decoding
  471. print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))
  472. tokenizer = AutoTokenizer.from_pretrained('openai-gpt') # Initialize tokenizer
  473. model = TFAutoModelWithLMHead.from_pretrained('openai-gpt') # Download model and configuration from S3 and cache.
  474. input_context = 'The dog'
  475. input_ids = tokenizer.encode(input_context, return_tensors='tf') # encode input context
  476. outputs = model.generate(input_ids=input_ids, num_beams=5, num_return_sequences=3, temperature=1.5) # generate 3 independent sequences using beam search decoding (5 beams) with sampling from initial context 'The dog'
  477. for i in range(3): # 3 output sequences were generated
  478. print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True)))
  479. tokenizer = AutoTokenizer.from_pretrained('distilgpt2') # Initialize tokenizer
  480. model = TFAutoModelWithLMHead.from_pretrained('distilgpt2') # Download model and configuration from S3 and cache.
  481. input_context = 'The dog'
  482. input_ids = tokenizer.encode(input_context, return_tensors='tf') # encode input context
  483. outputs = model.generate(input_ids=input_ids, max_length=40, temperature=0.7, num_return_sequences=3) # 3 generate sequences using by sampling
  484. for i in range(3): # 3 output sequences were generated
  485. print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True)))
  486. tokenizer = AutoTokenizer.from_pretrained('ctrl') # Initialize tokenizer
  487. model = TFAutoModelWithLMHead.from_pretrained('ctrl') # Download model and configuration from S3 and cache.
  488. input_context = 'Legal My neighbor is' # "Legal" is one of the control codes for ctrl
  489. input_ids = tokenizer.encode(input_context, return_tensors='tf') # encode input context
  490. outputs = model.generate(input_ids=input_ids, max_length=50, temperature=0.7, repetition_penalty=1.2) # generate sequences
  491. print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))
  492. tokenizer = AutoTokenizer.from_pretrained('gpt2') # Initialize tokenizer
  493. model = TFAutoModelWithLMHead.from_pretrained('gpt2') # Download model and configuration from S3 and cache.
  494. input_context = 'My cute dog' # "Legal" is one of the control codes for ctrl
  495. bad_words_ids = [tokenizer.encode(bad_word, add_prefix_space=True) for bad_word in ['idiot', 'stupid', 'shut up']]
  496. input_ids = tokenizer.encode(input_context, return_tensors='tf') # encode input context
  497. outputs = model.generate(input_ids=input_ids, max_length=100, do_sample=True, bad_words_ids=bad_words_ids) # generate sequences without allowing bad_words to be generated
  498. """
        # We cannot generate if the model does not have a LM head
        if self.get_output_embeddings() is None:
            raise AttributeError(
                "You tried to generate sequences with a model that does not have a LM Head. "
                "Please use another model class (e.g. `TFOpenAIGPTLMHeadModel`, `TFXLNetLMHeadModel`, `TFGPT2LMHeadModel`, `TFCTRLLMHeadModel`, `TFT5ForConditionalGeneration`, `TFTransfoXLLMHeadModel`)"
            )

        max_length = max_length if max_length is not None else self.config.max_length
        min_length = min_length if min_length is not None else self.config.min_length
        do_sample = do_sample if do_sample is not None else self.config.do_sample
        early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
        num_beams = num_beams if num_beams is not None else self.config.num_beams
        temperature = temperature if temperature is not None else self.config.temperature
        top_k = top_k if top_k is not None else self.config.top_k
        top_p = top_p if top_p is not None else self.config.top_p
        repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty
        bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
        length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
        no_repeat_ngram_size = (
            no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
        )
        bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids
        num_return_sequences = (
            num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
        )
        decoder_start_token_id = (
            decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
        )
        if input_ids is not None:
            batch_size = shape_list(input_ids)[0]  # overridden by the input batch_size
        else:
            batch_size = 1

        assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer."
        assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a positive integer."
        assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
        assert isinstance(early_stopping, bool), "`early_stopping` should be a boolean."
        assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictly positive integer."
        assert temperature > 0, "`temperature` should be strictly positive."
        assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer."
        assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1."
        assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1."
        assert input_ids is not None or (
            isinstance(bos_token_id, int) and bos_token_id >= 0
        ), "If input_ids is not defined, `bos_token_id` should be a positive integer."
        assert pad_token_id is None or (
            isinstance(pad_token_id, int) and (pad_token_id >= 0)
        ), "`pad_token_id` should be a positive integer."
        assert (eos_token_id is None) or (
            isinstance(eos_token_id, int) and (eos_token_id >= 0)
        ), "`eos_token_id` should be a positive integer."
        assert length_penalty > 0, "`length_penalty` should be strictly positive."
        assert (
            isinstance(num_return_sequences, int) and num_return_sequences > 0
        ), "`num_return_sequences` should be a strictly positive integer."
        assert (
            bad_words_ids is None or isinstance(bad_words_ids, list) and isinstance(bad_words_ids[0], list)
        ), "`bad_words_ids` is either `None` or a list of lists of tokens that should not be generated"

        if input_ids is None:
            assert isinstance(bos_token_id, int) and bos_token_id >= 0, (
                "you should either supply a context to complete as `input_ids` input "
                "or a `bos_token_id` (integer >= 0) as a first token to start the generation."
            )
            input_ids = tf.fill((batch_size, 1), bos_token_id)
        else:
            assert len(shape_list(input_ids)) == 2, "Input prompt should be of shape (batch_size, sequence length)."
        # do not allow duplicate outputs when greedy decoding
        if do_sample is False:
            if num_beams == 1:
                # no_beam_search greedy generation conditions
                assert (
                    num_return_sequences == 1
                ), "Greedy decoding will always produce the same output for num_beams == 1 and num_return_sequences > 1. Please set num_return_sequences = 1"
            else:
                # beam_search greedy generation conditions
                assert (
                    num_beams >= num_return_sequences
                ), "Greedy beam search decoding cannot return more sequences than it has beams. Please set num_beams >= num_return_sequences"

        # create attention mask if necessary
        # TODO (PVP): this should later be handled by the forward fn() in each model in the future see PR 3140
        if (attention_mask is None) and (pad_token_id is not None) and (pad_token_id in input_ids.numpy()):
            attention_mask = tf.cast(tf.math.not_equal(input_ids, pad_token_id), dtype=tf.int32)
        elif attention_mask is None:
            attention_mask = tf.ones_like(input_ids)

        if pad_token_id is None and eos_token_id is not None:
            log(
                "Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_id)
            )
            pad_token_id = eos_token_id

        # current position and vocab size
        cur_len = shape_list(input_ids)[1]
        vocab_size = self.config.vocab_size

        # set effective batch size and effective batch multiplier according to do_sample
        if do_sample:
            effective_batch_size = batch_size * num_return_sequences
            effective_batch_mult = num_return_sequences
        else:
            effective_batch_size = batch_size
            effective_batch_mult = 1

        # Expand input ids if num_beams > 1 or num_return_sequences > 1
        if num_return_sequences > 1 or num_beams > 1:
            input_ids_len = shape_list(input_ids)[-1]
            input_ids = tf.broadcast_to(
                tf.expand_dims(input_ids, 1), (batch_size, effective_batch_mult * num_beams, input_ids_len)
            )
            attention_mask = tf.broadcast_to(
                tf.expand_dims(attention_mask, 1), (batch_size, effective_batch_mult * num_beams, input_ids_len)
            )
            input_ids = tf.reshape(
                input_ids, (effective_batch_size * num_beams, input_ids_len)
            )  # shape: (batch_size * num_return_sequences * num_beams, cur_len)
            attention_mask = tf.reshape(
                attention_mask, (effective_batch_size * num_beams, input_ids_len)
            )  # shape: (batch_size * num_return_sequences * num_beams, cur_len)

        if self.config.is_encoder_decoder:
            if decoder_start_token_id is None:
                decoder_start_token_id = bos_token_id

            assert (
                decoder_start_token_id is not None
            ), "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation"
            assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self)
            assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder)

            # get encoder and store encoder outputs
            encoder = self.get_encoder()
            encoder_outputs = encoder(input_ids, attention_mask=attention_mask)

            # create empty decoder_input_ids
            input_ids = tf.ones((effective_batch_size * num_beams, 1), dtype=tf.int32,) * decoder_start_token_id
            cur_len = 1
        else:
            encoder_outputs = None
            cur_len = shape_list(input_ids)[-1]
        if num_beams > 1:
            output = self._generate_beam_search(
                input_ids,
                cur_len=cur_len,
                max_length=max_length,
                min_length=min_length,
                do_sample=do_sample,
                early_stopping=early_stopping,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                no_repeat_ngram_size=no_repeat_ngram_size,
                bad_words_ids=bad_words_ids,
                bos_token_id=bos_token_id,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                decoder_start_token_id=decoder_start_token_id,
                batch_size=effective_batch_size,
                num_return_sequences=num_return_sequences,
                length_penalty=length_penalty,
                num_beams=num_beams,
                vocab_size=vocab_size,
                encoder_outputs=encoder_outputs,
                attention_mask=attention_mask,
            )
        else:
            output = self._generate_no_beam_search(
                input_ids,
                cur_len=cur_len,
                max_length=max_length,
                min_length=min_length,
                do_sample=do_sample,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                no_repeat_ngram_size=no_repeat_ngram_size,
                bad_words_ids=bad_words_ids,
                bos_token_id=bos_token_id,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
                decoder_start_token_id=decoder_start_token_id,
                batch_size=effective_batch_size,
                vocab_size=vocab_size,
                encoder_outputs=encoder_outputs,
                attention_mask=attention_mask,
            )

        return output

    def _generate_no_beam_search(
        self,
        input_ids,
        cur_len,
        max_length,
        min_length,
        do_sample,
        temperature,
        top_k,
        top_p,
        repetition_penalty,
        no_repeat_ngram_size,
        bad_words_ids,
        bos_token_id,
        pad_token_id,
        eos_token_id,
        decoder_start_token_id,
        batch_size,
        vocab_size,
        encoder_outputs,
        attention_mask,
    ):
        """ Generate sequences for each example without beam search (num_beams == 1).
            All returned sequences are generated independently.
        """
        # length of generated sentences / unfinished sentences
        unfinished_sents = tf.ones_like(input_ids[:, 0])
        sent_lengths = tf.ones_like(input_ids[:, 0]) * max_length

        past = encoder_outputs  # defined for encoder-decoder models, None for decoder-only models

        while cur_len < max_length:
            model_inputs = self.prepare_inputs_for_generation(input_ids, past=past, attention_mask=attention_mask)
            outputs = self(**model_inputs)
            next_token_logits = outputs[0][:, -1, :]

            # if model has past, then set the past variable to speed up decoding
            if self._do_output_past(outputs):
                past = outputs[1]

            # repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
            if repetition_penalty != 1.0:
                next_token_logits_penalties = _create_next_token_logits_penalties(
                    input_ids, next_token_logits, repetition_penalty
                )
                next_token_logits = tf.math.multiply(next_token_logits, next_token_logits_penalties)

            if no_repeat_ngram_size > 0:
                # calculate a list of banned tokens to prevent repetitively generating the same ngrams
                # from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
                banned_tokens = calc_banned_ngram_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len)
                # create banned_tokens boolean mask
                banned_tokens_indices_mask = []
                for banned_tokens_slice in banned_tokens:
                    banned_tokens_indices_mask.append(
                        [True if token in banned_tokens_slice else False for token in range(vocab_size)]
                    )
                next_token_logits = set_tensor_by_indices_to_value(
                    next_token_logits, tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf")
                )

            if bad_words_ids is not None:
                # calculate a list of banned tokens according to bad words
                banned_tokens = calc_banned_bad_words_ids(input_ids, bad_words_ids)
                banned_tokens_indices_mask = []
                for banned_tokens_slice in banned_tokens:
                    banned_tokens_indices_mask.append(
                        [True if token in banned_tokens_slice else False for token in range(vocab_size)]
                    )
                next_token_logits = set_tensor_by_indices_to_value(
                    next_token_logits, tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf")
                )

            # set eos token prob to zero if min_length is not reached
            if eos_token_id is not None and cur_len < min_length:
                # create eos_token_id boolean mask (use equality, not identity, to compare token ids)
                is_token_logit_eos_token = tf.convert_to_tensor(
                    [True if token == eos_token_id else False for token in range(vocab_size)], dtype=tf.bool
                )
                eos_token_indices_mask = tf.broadcast_to(is_token_logit_eos_token, [batch_size, vocab_size])
                next_token_logits = set_tensor_by_indices_to_value(
                    next_token_logits, eos_token_indices_mask, -float("inf")
                )
            if do_sample:
                # Temperature (higher temperature => more likely to sample low probability tokens)
                if temperature != 1.0:
                    next_token_logits = next_token_logits / temperature
                # Top-p/top-k filtering
                next_token_logits = tf_top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
                # Sample
                next_token = tf.squeeze(
                    tf.random.categorical(next_token_logits, dtype=tf.int32, num_samples=1), axis=1
                )
            else:
                # Greedy decoding
                next_token = tf.math.argmax(next_token_logits, axis=-1, output_type=tf.int32)

            # update generations and finished sentences
            if eos_token_id is not None:
                # pad finished sentences if eos_token_id exists
                tokens_to_add = next_token * unfinished_sents + (pad_token_id) * (1 - unfinished_sents)
            else:
                tokens_to_add = next_token

            input_ids = tf.concat([input_ids, tf.expand_dims(tokens_to_add, -1)], 1)

            if eos_token_id is not None:
                eos_in_sents = tokens_to_add == eos_token_id
                # if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length
                is_sents_unfinished_and_token_to_add_is_eos = tf.math.multiply(
                    unfinished_sents, tf.cast(eos_in_sents, tf.int32)
                )
                sent_lengths = (
                    sent_lengths * (1 - is_sents_unfinished_and_token_to_add_is_eos)
                    + cur_len * is_sents_unfinished_and_token_to_add_is_eos
                )

                # unfinished_sents is set to zero if eos in sentence
                unfinished_sents -= is_sents_unfinished_and_token_to_add_is_eos

            # stop when there is a </s> in each sentence, or if we exceed the maximum length
            if tf.math.reduce_max(unfinished_sents) == 0:
                break

            # extend attention_mask for new generated input if only decoder
            if self.config.is_encoder_decoder is False:
                attention_mask = tf.concat(
                    [attention_mask, tf.ones((shape_list(attention_mask)[0], 1), dtype=tf.int32)], axis=-1
                )

            cur_len = cur_len + 1

        # if there are different sentence lengths in the batch, some batches have to be padded
        min_sent_length = tf.math.reduce_min(sent_lengths)
        max_sent_length = tf.math.reduce_max(sent_lengths)
        if min_sent_length != max_sent_length:
            assert pad_token_id is not None, "`pad_token_id` has to be defined if batches have different lengths"
            # finished sents are filled with pad_token
            padding = tf.ones([batch_size, max_sent_length.numpy()], dtype=tf.int32) * pad_token_id

            # create length masks for tf.where operation
            broad_casted_sent_lengths = tf.broadcast_to(
                tf.expand_dims(sent_lengths, -1), [batch_size, max_sent_length]
            )
            broad_casted_range = tf.transpose(
                tf.broadcast_to(tf.expand_dims(tf.range(max_length), -1), [max_length, batch_size])
            )
            decoded = tf.where(broad_casted_range < broad_casted_sent_lengths, input_ids, padding)
        else:
            decoded = input_ids

        return decoded
  814. def _generate_beam_search(
  815. self,
  816. input_ids,
  817. cur_len,
  818. max_length,
  819. min_length,
  820. do_sample,
  821. early_stopping,
  822. temperature,
  823. top_k,
  824. top_p,
  825. repetition_penalty,
  826. no_repeat_ngram_size,
  827. bad_words_ids,
  828. bos_token_id,
  829. pad_token_id,
  830. decoder_start_token_id,
  831. eos_token_id,
  832. batch_size,
  833. num_return_sequences,
  834. length_penalty,
  835. num_beams,
  836. vocab_size,
  837. encoder_outputs,
  838. attention_mask,
  839. ):
  840. """ Generate sequences for each example with beam search.
  841. """
  842. # generated hypotheses
  843. generated_hyps = [
  844. BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=early_stopping)
  845. for _ in range(batch_size)
  846. ]
847. # for greedy decoding, only tokens of the first beam are considered, to avoid sampling the exact same tokens across all beams
  848. if do_sample is False:
  849. beam_scores_begin = tf.zeros((batch_size, 1), dtype=tf.float32)
  850. beam_scores_end = tf.ones((batch_size, num_beams - 1), dtype=tf.float32) * (-1e9)
  851. beam_scores = tf.concat([beam_scores_begin, beam_scores_end], -1)
  852. else:
  853. beam_scores = tf.zeros((batch_size, num_beams), dtype=tf.float32)
  854. beam_scores = tf.reshape(beam_scores, (batch_size * num_beams,))
  855. # cache compute states
  856. past = encoder_outputs
  857. # done sentences
  858. done = [False for _ in range(batch_size)]
  859. while cur_len < max_length:
  860. model_inputs = self.prepare_inputs_for_generation(input_ids, past=past, attention_mask=attention_mask)
  861. outputs = self(**model_inputs) # (batch_size * num_beams, cur_len, vocab_size)
  862. next_token_logits = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size)
  863. # if model has past, then set the past variable to speed up decoding
  864. if self._do_output_past(outputs):
  865. past = outputs[1]
  866. # repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
  867. if repetition_penalty != 1.0:
  868. next_token_logits_penalties = _create_next_token_logits_penalties(
  869. input_ids, next_token_logits, repetition_penalty
  870. )
  871. next_token_logits = tf.math.multiply(next_token_logits, next_token_logits_penalties)
  872. # Temperature (higher temperature => more likely to sample low probability tokens)
  873. if temperature != 1.0:
  874. next_token_logits = next_token_logits / temperature
  875. # calculate log softmax score
  876. scores = tf.nn.log_softmax(next_token_logits, axis=-1) # (batch_size * num_beams, vocab_size)
  877. # set eos token prob to zero if min_length is not reached
  878. if eos_token_id is not None and cur_len < min_length:
  879. # create eos_token_id boolean mask
  880. num_batch_hypotheses = batch_size * num_beams
  881. is_token_logit_eos_token = tf.convert_to_tensor(
882. [token == eos_token_id for token in range(vocab_size)], dtype=tf.bool
  883. )
  884. eos_token_indices_mask = tf.broadcast_to(is_token_logit_eos_token, [num_batch_hypotheses, vocab_size])
  885. scores = set_tensor_by_indices_to_value(scores, eos_token_indices_mask, -float("inf"))
  886. if no_repeat_ngram_size > 0:
  887. # calculate a list of banned tokens to prevent repetitively generating the same ngrams
  888. # from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
  889. num_batch_hypotheses = batch_size * num_beams
  890. banned_tokens = calc_banned_ngram_tokens(
  891. input_ids, num_batch_hypotheses, no_repeat_ngram_size, cur_len
  892. )
  893. # create banned_tokens boolean mask
  894. banned_tokens_indices_mask = []
  895. for banned_tokens_slice in banned_tokens:
  896. banned_tokens_indices_mask.append(
  897. [True if token in banned_tokens_slice else False for token in range(vocab_size)]
  898. )
  899. scores = set_tensor_by_indices_to_value(
  900. scores, tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf")
  901. )
  902. if bad_words_ids is not None:
  903. # calculate a list of banned tokens according to bad words
  904. banned_tokens = calc_banned_bad_words_ids(input_ids, bad_words_ids)
  905. banned_tokens_indices_mask = []
  906. for banned_tokens_slice in banned_tokens:
  907. banned_tokens_indices_mask.append(
  908. [True if token in banned_tokens_slice else False for token in range(vocab_size)]
  909. )
  910. scores = set_tensor_by_indices_to_value(
  911. scores, tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf")
  912. )
  913. assert shape_list(scores) == [batch_size * num_beams, vocab_size]
  914. if do_sample:
  915. _scores = scores + tf.broadcast_to(
  916. beam_scores[:, None], (batch_size * num_beams, vocab_size)
  917. ) # (batch_size * num_beams, vocab_size)
  918. # Top-p/top-k filtering
  919. _scores = tf_top_k_top_p_filtering(
  920. _scores, top_k=top_k, top_p=top_p, min_tokens_to_keep=2
  921. ) # (batch_size * num_beams, vocab_size)
  922. # Sample 2 next tokens for each beam (so we have some spare tokens and match output of greedy beam search)
  923. _scores = tf.reshape(_scores, (batch_size, num_beams * vocab_size))
  924. next_tokens = tf.random.categorical(
  925. _scores, dtype=tf.int32, num_samples=2 * num_beams
  926. ) # (batch_size, 2 * num_beams)
  927. # Compute next scores
  928. next_scores = tf.gather(_scores, next_tokens, batch_dims=1) # (batch_size, 2 * num_beams)
  929. # sort the sampled vector to make sure that the first num_beams samples are the best
  930. next_scores_indices = tf.argsort(next_scores, direction="DESCENDING", axis=1)
  931. next_scores = tf.gather(next_scores, next_scores_indices, batch_dims=1) # (batch_size, num_beams * 2)
  932. next_tokens = tf.gather(next_tokens, next_scores_indices, batch_dims=1) # (batch_size, num_beams * 2)
  933. else:
  934. # Add the log prob of the new beams to the log prob of the beginning of the sequence (sum of logs == log of the product)
  935. next_scores = scores + tf.broadcast_to(
  936. beam_scores[:, None], (batch_size * num_beams, vocab_size)
  937. ) # (batch_size * num_beams, vocab_size)
938. # re-organize to group the beams together (we keep the top hypotheses across beams)
  939. next_scores = tf.reshape(
  940. next_scores, (batch_size, num_beams * vocab_size)
  941. ) # (batch_size, num_beams * vocab_size)
  942. next_scores, next_tokens = tf.math.top_k(next_scores, k=2 * num_beams, sorted=True)
  943. assert shape_list(next_scores) == shape_list(next_tokens) == [batch_size, 2 * num_beams]
  944. # next batch beam content
  945. next_batch_beam = []
  946. # for each sentence
  947. for batch_idx in range(batch_size):
  948. # if we are done with this sentence
  949. if done[batch_idx]:
  950. assert (
  951. len(generated_hyps[batch_idx]) >= num_beams
  952. ), "Batch can only be done if at least {} beams have been generated".format(num_beams)
  953. assert (
  954. eos_token_id is not None and pad_token_id is not None
  955. ), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined"
  956. next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams) # pad the batch
  957. continue
  958. # next sentence beam content
  959. next_sent_beam = []
  960. # next tokens for this sentence
  961. for beam_token_rank, (beam_token_id, beam_token_score) in enumerate(
  962. zip(next_tokens[batch_idx], next_scores[batch_idx])
  963. ):
  964. # get beam and token IDs
  965. beam_id = beam_token_id // vocab_size
  966. token_id = beam_token_id % vocab_size
  967. effective_beam_id = batch_idx * num_beams + beam_id
  968. # add to generated hypotheses if end of sentence or last iteration
  969. if (eos_token_id is not None) and (token_id.numpy() == eos_token_id):
  970. # if beam_token does not belong to top num_beams tokens, it should not be added
  971. is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams
  972. if is_beam_token_worse_than_top_num_beams:
  973. continue
  974. generated_hyps[batch_idx].add(
  975. tf.identity(input_ids[effective_beam_id]), beam_token_score.numpy()
  976. )
  977. else:
  978. # add next predicted token if it is not eos_token
  979. next_sent_beam.append((beam_token_score, token_id, effective_beam_id))
  980. # the beam for next step is full
  981. if len(next_sent_beam) == num_beams:
  982. break
983. # Check if we're done so that we can save a pad step if all(done)
  984. done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done(
  985. tf.reduce_max(next_scores[batch_idx]).numpy(), cur_len=cur_len
  986. )
  987. # update next beam content
  988. assert len(next_sent_beam) == num_beams, "Beam should always be full"
  989. next_batch_beam.extend(next_sent_beam)
  990. assert len(next_batch_beam) == num_beams * (batch_idx + 1)
  991. # stop when we are done with each sentence
  992. if all(done):
  993. break
  994. # sanity check / prepare next batch
  995. assert len(next_batch_beam) == batch_size * num_beams
  996. beam_scores = tf.convert_to_tensor([x[0] for x in next_batch_beam], dtype=tf.float32)
  997. beam_tokens = tf.convert_to_tensor([x[1] for x in next_batch_beam], dtype=tf.int32)
  998. beam_idx = tf.convert_to_tensor([x[2] for x in next_batch_beam], dtype=tf.int32)
  999. # re-order batch
  1000. input_ids = tf.stack([tf.identity(input_ids[x, :]) for x in beam_idx])
  1001. input_ids = tf.concat([input_ids, tf.expand_dims(beam_tokens, 1)], axis=-1)
  1002. # re-order internal states
  1003. if past is not None:
  1004. past = self._reorder_cache(past, beam_idx)
  1005. # extend attention_mask for new generated input if only decoder
  1006. if self.config.is_encoder_decoder is False:
  1007. attention_mask = tf.concat(
  1008. [attention_mask, tf.ones((shape_list(attention_mask)[0], 1), dtype=tf.int32)], axis=-1
  1009. )
  1010. # update current length
  1011. cur_len = cur_len + 1
1012. # finalize all open beam hypotheses and add them to generated hypotheses
  1013. for batch_idx in range(batch_size):
1014. # Add all open beam hypotheses to generated_hyps
  1015. if done[batch_idx]:
  1016. continue
  1017. # test that beam scores match previously calculated scores if not eos and batch_idx not done
  1018. if eos_token_id is not None and all(
1019. (token_id % vocab_size).numpy().item() != eos_token_id for token_id in next_tokens[batch_idx]
  1020. ):
  1021. assert tf.reduce_all(
  1022. next_scores[batch_idx, :num_beams] == tf.reshape(beam_scores, (batch_size, num_beams))[batch_idx]
  1023. ), "If batch_idx is not done, final next scores: {} have to equal to accumulated beam_scores: {}".format(
  1024. next_scores[:, :num_beams][batch_idx], tf.reshape(beam_scores, (batch_size, num_beams))[batch_idx]
  1025. )
  1026. # need to add best num_beams hypotheses to generated hyps
  1027. for beam_id in range(num_beams):
  1028. effective_beam_id = batch_idx * num_beams + beam_id
  1029. final_score = beam_scores[effective_beam_id].numpy().item()
  1030. final_tokens = input_ids[effective_beam_id]
  1031. generated_hyps[batch_idx].add(final_tokens, final_score)
  1032. # depending on whether greedy generation is wanted or not define different output_batch_size and output_num_return_sequences_per_batch
  1033. output_batch_size = batch_size if do_sample else batch_size * num_return_sequences
  1034. output_num_return_sequences_per_batch = 1 if do_sample else num_return_sequences
  1035. # select the best hypotheses
  1036. sent_lengths_list = []
  1037. best = []
  1038. # retrieve best hypotheses
  1039. for i, hypotheses in enumerate(generated_hyps):
  1040. sorted_hyps = sorted(hypotheses.beams, key=lambda x: x[0])
  1041. for j in range(output_num_return_sequences_per_batch):
  1042. best_hyp = sorted_hyps.pop()[1]
  1043. sent_lengths_list.append(len(best_hyp))
  1044. best.append(best_hyp)
  1045. assert output_batch_size == len(best), "Output batch size {} must match output beam hypotheses {}".format(
  1046. output_batch_size, len(best)
  1047. )
  1048. sent_lengths = tf.convert_to_tensor(sent_lengths_list, dtype=tf.int32)
  1049. # shorter batches are filled with pad_token
  1050. if tf.reduce_min(sent_lengths).numpy() != tf.reduce_max(sent_lengths).numpy():
1051. assert pad_token_id is not None, "`pad_token_id` has to be defined"
  1052. sent_max_len = min(tf.reduce_max(sent_lengths).numpy() + 1, max_length)
  1053. decoded_list = []
  1054. # fill with hypothesis and eos_token_id if necessary
  1055. for i, hypo in enumerate(best):
  1056. assert sent_lengths[i] == shape_list(hypo)[0]
  1057. # if sent_length is max_len do not pad
  1058. if sent_lengths[i] == sent_max_len:
  1059. decoded_slice = hypo
  1060. else:
  1061. # else pad to sent_max_len
  1062. num_pad_tokens = sent_max_len - sent_lengths[i]
  1063. padding = pad_token_id * tf.ones((num_pad_tokens,), dtype=tf.int32)
  1064. decoded_slice = tf.concat([hypo, padding], axis=-1)
  1065. # finish sentence with EOS token
  1066. if sent_lengths[i] < max_length:
  1067. decoded_slice = tf.where(
  1068. tf.range(sent_max_len, dtype=tf.int32) == sent_lengths[i],
  1069. eos_token_id * tf.ones((sent_max_len,), dtype=tf.int32),
  1070. decoded_slice,
  1071. )
  1072. # add to list
  1073. decoded_list.append(decoded_slice)
  1074. decoded = tf.stack(decoded_list)
  1075. else:
  1076. # none of the hypotheses have an eos_token
1077. assert all(len(hypo) == max_length for hypo in best)
  1078. decoded = tf.stack(best)
  1079. return decoded
  1080. @staticmethod
  1081. def _reorder_cache(past, beam_idx):
  1082. reordered_past = []
  1083. for layer_past in past:
  1084. # get the correct batch idx from layer past batch dim
  1085. # batch dim of `past` and `mems` is at 2nd position
  1086. reordered_layer_past = [tf.identity(tf.expand_dims(layer_past[:, i], 1)) for i in beam_idx]
  1087. reordered_layer_past = tf.concat(reordered_layer_past, axis=1)
  1088. # check that shape matches
  1089. assert shape_list(reordered_layer_past) == shape_list(layer_past)
  1090. reordered_past.append(reordered_layer_past)
  1091. past = tuple(reordered_past)
  1092. return past
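# Illustrative sketch (not part of the original module): with beam_idx = [2, 0, 1],
# each layer's cached states are re-ordered along the beam dimension (dim 1), so the
# new hypothesis 0 continues from the cache of what was previously hypothesis 2:
#   reordered_layer_past[:, 0] == layer_past[:, 2]
#   reordered_layer_past[:, 1] == layer_past[:, 0]
#   reordered_layer_past[:, 2] == layer_past[:, 1]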
  1093. def _create_next_token_logits_penalties(input_ids, logits, repetition_penalty):
  1094. # create logit penalties for already seen input_ids
  1095. token_penalties = np.ones(shape_list(logits))
  1096. prev_input_ids = [np.unique(input_id) for input_id in input_ids.numpy()]
  1097. for i, prev_input_id in enumerate(prev_input_ids):
  1098. logit_penalized = logits[i].numpy()[prev_input_id]
  1099. logit_penalties = np.zeros(logit_penalized.shape)
  1100. # if previous logit score is < 0 then multiply repetition penalty else divide
  1101. logit_penalties[logit_penalized < 0] = repetition_penalty
  1102. logit_penalties[logit_penalized > 0] = 1 / repetition_penalty
  1103. np.put(token_penalties[i], prev_input_id, logit_penalties)
  1104. return tf.convert_to_tensor(token_penalties, dtype=tf.float32)
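# Illustrative sketch (not part of the original module): with repetition_penalty=1.2,
# a vocab of 3, logits [[1.5, 0.3, -2.0]] and previously generated ids [[0, 2]], the
# returned multiplicative penalties are [[1/1.2, 1.0, 1.2]] -- already-seen tokens with
# a positive logit are scaled down, already-seen tokens with a negative logit are pushed
# further down, and unseen tokens are left unchanged:
#   penalties = _create_next_token_logits_penalties(input_ids, next_token_logits, 1.2)
#   next_token_logits = tf.math.multiply(next_token_logits, penalties)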
  1105. def calc_banned_ngram_tokens(prev_input_ids, num_hypos, no_repeat_ngram_size, cur_len):
1106. # Copied from fairseq for no_repeat_ngram in beam_search
  1107. if cur_len + 1 < no_repeat_ngram_size:
  1108. # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
  1109. return [[] for _ in range(num_hypos)]
  1110. generated_ngrams = [{} for _ in range(num_hypos)]
  1111. for idx in range(num_hypos):
  1112. gen_tokens = prev_input_ids[idx].numpy().tolist()
  1113. generated_ngram = generated_ngrams[idx]
  1114. for ngram in zip(*[gen_tokens[i:] for i in range(no_repeat_ngram_size)]):
  1115. prev_ngram_tuple = tuple(ngram[:-1])
  1116. generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]
  1117. def _get_generated_ngrams(hypo_idx):
  1118. # Before decoding the next token, prevent decoding of ngrams that have already appeared
  1119. start_idx = cur_len + 1 - no_repeat_ngram_size
  1120. ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:cur_len].numpy().tolist())
  1121. return generated_ngrams[hypo_idx].get(ngram_idx, [])
  1122. banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)]
  1123. return banned_tokens
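# Illustrative sketch (not part of the original module): with no_repeat_ngram_size=2
# and a single hypothesis whose generated ids are [5, 7, 5], the bigram (5, 7) has
# already been produced, so with the current prefix ending in 5 the token 7 is banned:
#   prev = tf.constant([[5, 7, 5]])
#   calc_banned_ngram_tokens(prev, num_hypos=1, no_repeat_ngram_size=2, cur_len=3)
#   # -> [[7]]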
  1124. def calc_banned_bad_words_ids(prev_input_ids, bad_words_ids):
  1125. banned_tokens = []
  1126. def _tokens_match(prev_tokens, tokens):
  1127. if len(tokens) == 0:
  1128. # if bad word tokens is just one token always ban it
  1129. return True
1130. if len(tokens) > len(prev_tokens):
1131. # if bad word tokens are longer than the previous tokens they can't be equal
  1132. return False
  1133. if prev_tokens[-len(tokens) :] == tokens:
  1134. # if tokens match
  1135. return True
  1136. else:
  1137. return False
  1138. for prev_input_ids_slice in prev_input_ids:
  1139. banned_tokens_slice = []
  1140. for banned_token_seq in bad_words_ids:
1141. assert len(banned_token_seq) > 0, "Banned word token sequences {} cannot contain an empty list".format(
  1142. bad_words_ids
  1143. )
  1144. if _tokens_match(prev_input_ids_slice.numpy().tolist(), banned_token_seq[:-1]) is False:
  1145. # if tokens do not match continue
  1146. continue
  1147. banned_tokens_slice.append(banned_token_seq[-1])
  1148. banned_tokens.append(banned_tokens_slice)
  1149. return banned_tokens
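# Illustrative sketch (not part of the original module): with
# bad_words_ids = [[42], [7, 9]] and a hypothesis whose ids end in 7, the single-token
# bad word 42 is always banned and 9 is banned because the prefix [7] matches the end
# of the hypothesis:
#   calc_banned_bad_words_ids(tf.constant([[3, 7]]), [[42], [7, 9]])
#   # -> [[42, 9]]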
  1150. def tf_top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1):
  1151. """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
  1152. Args:
  1153. logits: logits distribution shape (batch size, vocabulary size)
  1154. if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
  1155. if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
  1156. Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
  1157. Make sure we keep at least min_tokens_to_keep per batch example in the output
  1158. From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
  1159. """
  1160. logits_shape = shape_list(logits)
  1161. if top_k > 0:
  1162. top_k = min(max(top_k, min_tokens_to_keep), logits_shape[-1]) # Safety check
  1163. # Remove all tokens with a probability less than the last token of the top-k
  1164. indices_to_remove = logits < tf.math.top_k(logits, k=top_k)[0][..., -1, None]
  1165. logits = set_tensor_by_indices_to_value(logits, indices_to_remove, filter_value)
  1166. if top_p < 1.0:
  1167. sorted_indices = tf.argsort(logits, direction="DESCENDING")
  1168. sorted_logits = tf.gather(
  1169. logits, sorted_indices, axis=-1, batch_dims=1
  1170. ) # expects logits to be of dim (batch_size, vocab_size)
  1171. cumulative_probs = tf.math.cumsum(tf.nn.softmax(sorted_logits, axis=-1), axis=-1)
1173. # Remove tokens with cumulative probability above the threshold (tokens with 0 are kept)
  1173. sorted_indices_to_remove = cumulative_probs > top_p
  1174. if min_tokens_to_keep > 1:
  1175. # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
  1176. sorted_indices_to_remove = tf.concat(
  1177. [
  1178. tf.zeros_like(sorted_indices_to_remove[:, :min_tokens_to_keep]),
  1179. sorted_indices_to_remove[:, min_tokens_to_keep:],
  1180. ],
  1181. -1,
  1182. )
  1183. # Shift the indices to the right to keep also the first token above the threshold
  1184. sorted_indices_to_remove = tf.roll(sorted_indices_to_remove, 1, axis=-1)
  1185. sorted_indices_to_remove = tf.concat(
  1186. [tf.zeros_like(sorted_indices_to_remove[:, :1]), sorted_indices_to_remove[:, 1:]], -1,
  1187. )
  1188. # scatter sorted tensors to original indexing
  1189. indices_to_remove = scatter_values_on_batch_indices(sorted_indices_to_remove, sorted_indices)
  1190. logits = set_tensor_by_indices_to_value(logits, indices_to_remove, filter_value)
  1191. return logits
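# Usage sketch (illustrative values, not part of the original module):
#   logits = tf.constant([[1.0, 2.0, 3.0, 4.0]])
#   tf_top_k_top_p_filtering(logits, top_k=2)
#   # -> [[-inf, -inf, 3.0, 4.0]]   only the two highest logits survive
#   tf_top_k_top_p_filtering(logits, top_p=0.9)
#   # -> keeps the smallest set of highest-probability tokens whose cumulative
#   #    probability reaches 0.9 (here the three largest logits) and masks the rest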
  1192. def scatter_values_on_batch_indices(values, batch_indices):
  1193. shape = shape_list(batch_indices)
  1194. # broadcast batch dim to shape
  1195. broad_casted_batch_dims = tf.reshape(tf.broadcast_to(tf.expand_dims(tf.range(shape[0]), axis=-1), shape), [1, -1])
  1196. # transform batch_indices to pair_indices
  1197. pair_indices = tf.transpose(tf.concat([broad_casted_batch_dims, tf.reshape(batch_indices, [1, -1])], 0))
  1198. # scatter values to pair indices
  1199. return tf.scatter_nd(pair_indices, tf.reshape(values, [-1]), shape)
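# Illustrative sketch (not part of the original module): the value at sorted position j
# of row i is written back to column batch_indices[i, j], undoing the earlier argsort:
#   scatter_values_on_batch_indices(tf.constant([[True, False, False]]),
#                                   tf.constant([[2, 0, 1]]))
#   # -> [[False, False, True]]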
  1200. def set_tensor_by_indices_to_value(tensor, indices, value):
  1201. # create value_tensor since tensor value assignment is not possible in TF
  1202. value_tensor = tf.zeros_like(tensor) + value
  1203. return tf.where(indices, value_tensor, tensor)
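# Illustrative sketch (not part of the original module): this is the TF equivalent of
# masked in-place assignment, e.g.
#   set_tensor_by_indices_to_value(tf.constant([1.0, 2.0]),
#                                  tf.constant([False, True]), -float("inf"))
#   # -> [1.0, -inf]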
  1204. class BeamHypotheses(object):
  1205. def __init__(self, num_beams, max_length, length_penalty, early_stopping):
  1206. """
  1207. Initialize n-best list of hypotheses.
  1208. """
  1209. self.max_length = max_length - 1 # ignoring bos_token
  1210. self.length_penalty = length_penalty
  1211. self.early_stopping = early_stopping
  1212. self.num_beams = num_beams
  1213. self.beams = []
  1214. self.worst_score = 1e9
  1215. def __len__(self):
  1216. """
  1217. Number of hypotheses in the list.
  1218. """
  1219. return len(self.beams)
  1220. def add(self, hyp, sum_logprobs):
  1221. """
  1222. Add a new hypothesis to the list.
  1223. """
  1224. score = sum_logprobs / len(hyp) ** self.length_penalty
  1225. if len(self) < self.num_beams or score > self.worst_score:
  1226. self.beams.append((score, hyp))
  1227. if len(self) > self.num_beams:
  1228. sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.beams)])
  1229. del self.beams[sorted_scores[0][1]]
  1230. self.worst_score = sorted_scores[1][0]
  1231. else:
  1232. self.worst_score = min(score, self.worst_score)
  1233. def is_done(self, best_sum_logprobs, cur_len=None):
  1234. """
1235. If there are enough hypotheses and none of the hypotheses being generated
  1236. can become better than the worst one in the heap, then we are done with this sentence.
  1237. """
  1238. if len(self) < self.num_beams:
  1239. return False
  1240. elif self.early_stopping:
  1241. return True
  1242. else:
  1243. if cur_len is None:
  1244. cur_len = self.max_length
  1245. cur_score = best_sum_logprobs / cur_len ** self.length_penalty
  1246. ret = self.worst_score >= cur_score
  1247. return ret
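# Illustrative sketch (not part of the original module): hypotheses are ranked by their
# length-normalized score sum_logprobs / len(hyp) ** length_penalty. With
# length_penalty=1.0, a 4-token hypothesis with total log-prob -6.0 scores -1.5; once
# num_beams hypotheses are stored, a new hypothesis only enters the list if its score
# beats the current worst_score, which is how is_done() can decide that no open beam
# can still improve the n-best list.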
  1248. class TFConv1D(tf.keras.layers.Layer):
  1249. def __init__(self, nf, nx, initializer_range=0.02, **kwargs):
  1250. """ TFConv1D layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2)
  1251. Basically works like a Linear layer but the weights are transposed
  1252. """
  1253. super().__init__(**kwargs)
  1254. self.nf = nf
  1255. self.nx = nx
  1256. self.initializer_range = initializer_range
  1257. def build(self, input_shape):
  1258. self.weight = self.add_weight(
  1259. "weight", shape=[self.nx, self.nf], initializer=get_initializer(self.initializer_range)
  1260. )
  1261. self.bias = self.add_weight("bias", shape=[1, self.nf], initializer=tf.zeros_initializer())
  1262. def call(self, x):
  1263. bz, sl = shape_list(x)[:2]
  1264. x = tf.reshape(x, [-1, self.nx])
  1265. x = tf.matmul(x, self.weight) + self.bias
  1266. x = tf.reshape(x, [bz, sl, self.nf])
  1267. return x
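# Usage sketch (illustrative hyper-parameters, not part of the original module): the
# layer behaves like a Dense layer with a [nx, nf] kernel, e.g. a fused query/key/value
# projection in GPT-2-style models:
#   c_attn = TFConv1D(nf=3 * 768, nx=768)
#   # input [batch, seq, 768] -> output [batch, seq, 2304]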
  1268. class TFSharedEmbeddings(tf.keras.layers.Layer):
  1269. """Construct shared token embeddings.
  1270. """
  1271. def __init__(self, vocab_size, hidden_size, initializer_range=None, **kwargs):
  1272. super().__init__(**kwargs)
  1273. self.vocab_size = vocab_size
  1274. self.hidden_size = hidden_size
  1275. self.initializer_range = hidden_size ** -0.5 if initializer_range is None else initializer_range
  1276. def build(self, input_shape):
  1277. """Build shared token embedding layer
  1278. Shared weights logic adapted from
  1279. https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
  1280. """
  1281. self.weight = self.add_weight(
  1282. "weight", shape=[self.vocab_size, self.hidden_size], initializer=get_initializer(self.initializer_range)
  1283. )
  1284. super().build(input_shape)
  1285. def call(self, inputs, mode="embedding"):
  1286. """Get token embeddings of inputs.
  1287. Args:
  1288. inputs: list of three int64 tensors with shape [batch_size, length]: (input_ids, position_ids, token_type_ids)
  1289. mode: string, a valid value is one of "embedding" and "linear".
  1290. Returns:
  1291. outputs: (1) If mode == "embedding", output embedding tensor, float32 with
  1292. shape [batch_size, length, embedding_size]; (2) mode == "linear", output
  1293. linear tensor, float32 with shape [batch_size, length, vocab_size].
  1294. Raises:
  1295. ValueError: if mode is not valid.
  1296. Shared weights logic adapted from
  1297. https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
  1298. """
  1299. if mode == "embedding":
  1300. return self._embedding(inputs)
  1301. elif mode == "linear":
  1302. return self._linear(inputs)
  1303. else:
  1304. raise ValueError("mode {} is not valid.".format(mode))
  1305. def _embedding(self, input_ids):
  1306. """Applies embedding based on inputs tensor."""
  1307. return tf.gather(self.weight, input_ids)
  1308. def _linear(self, inputs):
  1309. """Computes logits by running inputs through a linear layer.
  1310. Args:
  1311. inputs: A float32 tensor with shape [..., hidden_size]
  1312. Returns:
  1313. float32 tensor with shape [..., vocab_size].
  1314. """
  1315. first_dims = shape_list(inputs)[:-1]
  1316. x = tf.reshape(inputs, [-1, self.hidden_size])
  1317. logits = tf.matmul(x, self.weight, transpose_b=True)
  1318. return tf.reshape(logits, first_dims + [self.vocab_size])
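# Usage sketch (illustrative, not part of the original module): the same weight matrix
# both embeds token ids and projects hidden states back to vocabulary logits (weight tying):
#   shared = TFSharedEmbeddings(vocab_size=32000, hidden_size=768)
#   hidden = shared(input_ids, mode="embedding")  # [batch, length, 768]
#   logits = shared(hidden, mode="linear")        # [batch, length, 32000]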
  1319. class TFSequenceSummary(tf.keras.layers.Layer):
  1320. r""" Compute a single vector summary of a sequence hidden states according to various possibilities:
  1321. Args of the config class:
  1322. summary_type:
  1323. - 'last' => [default] take the last token hidden state (like XLNet)
  1324. - 'first' => take the first token hidden state (like Bert)
  1325. - 'mean' => take the mean of all tokens hidden states
  1326. - 'cls_index' => supply a Tensor of classification token position (GPT/GPT-2)
  1327. - 'attn' => Not implemented now, use multi-head attention
  1328. summary_use_proj: Add a projection after the vector extraction
  1329. summary_proj_to_labels: If True, the projection outputs to config.num_labels classes (otherwise to hidden_size). Default: False.
1330. summary_activation: 'tanh' => add a tanh activation to the output, Other => no activation. Default: no activation.
  1331. summary_first_dropout: Add a dropout before the projection and activation
  1332. summary_last_dropout: Add a dropout after the projection and activation
  1333. """
  1334. def __init__(self, config, initializer_range=0.02, **kwargs):
  1335. super().__init__(**kwargs)
1336. self.summary_type = config.summary_type if hasattr(config, "summary_type") else "last"
  1337. if self.summary_type == "attn":
  1338. # We should use a standard multi-head attention module with absolute positional embedding for that.
  1339. # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
  1340. # We can probably just use the multi-head attention module of PyTorch >=1.1.0
  1341. raise NotImplementedError
  1342. self.has_summary = hasattr(config, "summary_use_proj") and config.summary_use_proj
  1343. if self.has_summary:
  1344. if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
  1345. num_classes = config.num_labels
  1346. else:
  1347. num_classes = config.hidden_size
  1348. self.summary = tf.keras.layers.Dense(
  1349. num_classes, kernel_initializer=get_initializer(initializer_range), name="summary"
  1350. )
  1351. self.has_activation = hasattr(config, "summary_activation") and config.summary_activation == "tanh"
  1352. if self.has_activation:
  1353. self.activation = tf.keras.activations.tanh
  1354. self.has_first_dropout = hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0
  1355. if self.has_first_dropout:
  1356. self.first_dropout = tf.keras.layers.Dropout(config.summary_first_dropout)
  1357. self.has_last_dropout = hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0
  1358. if self.has_last_dropout:
  1359. self.last_dropout = tf.keras.layers.Dropout(config.summary_last_dropout)
  1360. def call(self, inputs, training=False):
  1361. """ hidden_states: float Tensor in shape [bsz, seq_len, hidden_size], the hidden-states of the last layer.
  1362. cls_index: [optional] position of the classification token if summary_type == 'cls_index',
  1363. shape (bsz,) or more generally (bsz, ...) where ... are optional leading dimensions of hidden_states.
  1364. if summary_type == 'cls_index' and cls_index is None:
  1365. we take the last token of the sequence as classification token
  1366. """
  1367. if not isinstance(inputs, (dict, tuple, list)):
  1368. hidden_states = inputs
  1369. cls_index = None
  1370. elif isinstance(inputs, (tuple, list)):
  1371. hidden_states = inputs[0]
  1372. cls_index = inputs[1] if len(inputs) > 1 else None
  1373. assert len(inputs) <= 2, "Too many inputs."
  1374. else:
  1375. hidden_states = inputs.get("hidden_states")
  1376. cls_index = inputs.get("cls_index", None)
  1377. if self.summary_type == "last":
  1378. output = hidden_states[:, -1]
  1379. elif self.summary_type == "first":
  1380. output = hidden_states[:, 0]
  1381. elif self.summary_type == "mean":
  1382. output = tf.reduce_mean(hidden_states, axis=1)
  1383. elif self.summary_type == "cls_index":
  1384. hidden_shape = shape_list(hidden_states) # e.g. [batch, num choices, seq length, hidden dims]
  1385. if cls_index is None:
  1386. cls_index = tf.fill(
  1387. hidden_shape[:-2], hidden_shape[-2] - 1
1388. ) # A tensor of shape [batch] or [batch, num choices] filled with the index of the last sequence position
  1389. cls_shape = shape_list(cls_index)
  1390. if len(cls_shape) <= len(hidden_shape) - 2:
  1391. cls_index = cls_index[..., tf.newaxis]
  1392. # else:
  1393. # cls_index = cls_index[..., tf.newaxis]
  1394. # cls_index = cls_index.expand((-1,) * (cls_index.dim()-1) + (hidden_states.size(-1),))
  1395. # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
  1396. output = tf.gather(hidden_states, cls_index, batch_dims=len(hidden_shape) - 2)
  1397. output = tf.squeeze(
  1398. output, axis=len(hidden_shape) - 2
  1399. ) # shape of output: (batch, num choices, hidden_size)
  1400. elif self.summary_type == "attn":
  1401. raise NotImplementedError
  1402. if self.has_first_dropout:
  1403. output = self.first_dropout(output, training=training)
  1404. if self.has_summary:
  1405. output = self.summary(output)
  1406. if self.has_activation:
  1407. output = self.activation(output)
  1408. if self.has_last_dropout:
  1409. output = self.last_dropout(output, training=training)
  1410. return output
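# Usage sketch (illustrative; assumes a config with summary_type="last" and
# summary_use_proj=True; not part of the original module):
#   summary = TFSequenceSummary(config)
#   pooled = summary(hidden_states)  # [batch, seq_len, hidden] -> [batch, num_classes or hidden]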
  1411. def shape_list(x):
  1412. """Deal with dynamic shape in tensorflow cleanly."""
  1413. static = x.shape.as_list()
  1414. dynamic = tf.shape(x)
  1415. return [dynamic[i] if s is None else s for i, s in enumerate(static)]
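# Illustrative sketch (not part of the original module): for a tensor whose static shape
# is [None, 128], shape_list returns [<dynamic batch-size tensor>, 128], so reshapes keep
# working inside tf.function even when the batch size is unknown at trace time.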
  1416. def get_initializer(initializer_range=0.02):
  1417. """Creates a `tf.initializers.truncated_normal` with the given range.
  1418. Args:
  1419. initializer_range: float, initializer range for stddev.
  1420. Returns:
  1421. TruncatedNormal initializer with stddev = `initializer_range`.
  1422. """
  1423. return tf.keras.initializers.TruncatedNormal(stddev=initializer_range)
  1424. TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
  1425. "bert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-tf_model.h5",
  1426. "bert-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-tf_model.h5",
  1427. "bert-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-tf_model.h5",
  1428. "bert-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-tf_model.h5",
  1429. "bert-base-multilingual-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-tf_model.h5",
  1430. "bert-base-multilingual-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-tf_model.h5",
  1431. "bert-base-chinese": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-tf_model.h5",
  1432. "bert-base-german-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-tf_model.h5",
  1433. "bert-large-uncased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-tf_model.h5",
  1434. "bert-large-cased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-tf_model.h5",
  1435. "bert-large-uncased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-tf_model.h5",
  1436. "bert-large-cased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-tf_model.h5",
  1437. "bert-base-cased-finetuned-mrpc": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-tf_model.h5",
  1438. "bert-base-japanese": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-tf_model.h5",
  1439. "bert-base-japanese-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-whole-word-masking-tf_model.h5",
  1440. "bert-base-japanese-char": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-tf_model.h5",
  1441. "bert-base-japanese-char-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking-tf_model.h5",
  1442. "bert-base-finnish-cased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/tf_model.h5",
  1443. "bert-base-finnish-uncased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/tf_model.h5",
  1444. "bert-base-dutch-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/wietsedv/bert-base-dutch-cased/tf_model.h5",
  1445. }
  1446. def gelu(x):
  1447. """ Gaussian Error Linear Unit.
  1448. Original Implementation of the gelu activation function in Google Bert repo when initially created.
  1449. For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
  1450. 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
  1451. Also see https://arxiv.org/abs/1606.08415
  1452. """
  1453. cdf = 0.5 * (1.0 + tf.math.erf(x / tf.math.sqrt(2.0)))
  1454. return x * cdf
  1455. def gelu_new(x):
  1456. """Gaussian Error Linear Unit.
  1457. This is a smoother version of the RELU.
  1458. Original paper: https://arxiv.org/abs/1606.08415
  1459. Args:
  1460. x: float Tensor to perform activation.
  1461. Returns:
  1462. `x` with the GELU activation applied.
  1463. """
  1464. cdf = 0.5 * (1.0 + tf.tanh((np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3)))))
  1465. return x * cdf
  1466. def swish(x):
  1467. return x * tf.sigmoid(x)
  1468. ACT2FN = {
  1469. "gelu": tf.keras.layers.Activation(gelu),
  1470. "relu": tf.keras.activations.relu,
  1471. "swish": tf.keras.layers.Activation(swish),
  1472. "gelu_new": tf.keras.layers.Activation(gelu_new),
  1473. }
  1474. class TFBertEmbeddings(tf.keras.layers.Layer):
  1475. """Construct the embeddings from word, position and token_type embeddings.
  1476. """
  1477. def __init__(self, config, **kwargs):
  1478. super().__init__(**kwargs)
  1479. self.vocab_size = config.vocab_size
  1480. self.hidden_size = config.hidden_size
  1481. self.initializer_range = config.initializer_range
  1482. self.position_embeddings = tf.keras.layers.Embedding(
  1483. config.max_position_embeddings,
  1484. config.hidden_size,
  1485. embeddings_initializer=get_initializer(self.initializer_range),
  1486. name="position_embeddings",
  1487. )
  1488. self.token_type_embeddings = tf.keras.layers.Embedding(
  1489. config.type_vocab_size,
  1490. config.hidden_size,
  1491. embeddings_initializer=get_initializer(self.initializer_range),
  1492. name="token_type_embeddings",
  1493. )
  1494. # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
  1495. # any TensorFlow checkpoint file
  1496. self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
  1497. self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
  1498. def build(self, input_shape):
  1499. """Build shared word embedding layer """
  1500. with tf.name_scope("word_embeddings"):
  1501. # Create and initialize weights. The random normal initializer was chosen
  1502. # arbitrarily, and works well.
  1503. self.word_embeddings = self.add_weight(
  1504. "weight",
  1505. shape=[self.vocab_size, self.hidden_size],
  1506. initializer=get_initializer(self.initializer_range),
  1507. )
  1508. super().build(input_shape)
  1509. def call(self, inputs, mode="embedding", training=False):
  1510. """Get token embeddings of inputs.
  1511. Args:
  1512. inputs: list of three int64 tensors with shape [batch_size, length]: (input_ids, position_ids, token_type_ids)
  1513. mode: string, a valid value is one of "embedding" and "linear".
  1514. Returns:
  1515. outputs: (1) If mode == "embedding", output embedding tensor, float32 with
  1516. shape [batch_size, length, embedding_size]; (2) mode == "linear", output
  1517. linear tensor, float32 with shape [batch_size, length, vocab_size].
  1518. Raises:
  1519. ValueError: if mode is not valid.
  1520. Shared weights logic adapted from
  1521. https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
  1522. """
  1523. if mode == "embedding":
  1524. return self._embedding(inputs, training=training)
  1525. elif mode == "linear":
  1526. return self._linear(inputs)
  1527. else:
  1528. raise ValueError("mode {} is not valid.".format(mode))
  1529. def _embedding(self, inputs, training=False):
  1530. """Applies embedding based on inputs tensor."""
  1531. input_ids, position_ids, token_type_ids, inputs_embeds = inputs
  1532. if input_ids is not None:
  1533. input_shape = shape_list(input_ids)
  1534. else:
  1535. input_shape = shape_list(inputs_embeds)[:-1]
  1536. seq_length = input_shape[1]
  1537. if position_ids is None:
  1538. position_ids = tf.range(seq_length, dtype=tf.int32)[tf.newaxis, :]
  1539. if token_type_ids is None:
  1540. token_type_ids = tf.fill(input_shape, 0)
  1541. if inputs_embeds is None:
  1542. inputs_embeds = tf.gather(self.word_embeddings, input_ids)
  1543. position_embeddings = self.position_embeddings(position_ids)
  1544. token_type_embeddings = self.token_type_embeddings(token_type_ids)
  1545. embeddings = inputs_embeds + position_embeddings + token_type_embeddings
  1546. embeddings = self.LayerNorm(embeddings)
  1547. embeddings = self.dropout(embeddings, training=training)
  1548. return embeddings
  1549. def _linear(self, inputs):
  1550. """Computes logits by running inputs through a linear layer.
  1551. Args:
  1552. inputs: A float32 tensor with shape [batch_size, length, hidden_size]
  1553. Returns:
  1554. float32 tensor with shape [batch_size, length, vocab_size].
  1555. """
  1556. batch_size = shape_list(inputs)[0]
  1557. length = shape_list(inputs)[1]
  1558. x = tf.reshape(inputs, [-1, self.hidden_size])
  1559. logits = tf.matmul(x, self.word_embeddings, transpose_b=True)
  1560. return tf.reshape(logits, [batch_size, length, self.vocab_size])
  1561. class TFBertSelfAttention(tf.keras.layers.Layer):
  1562. def __init__(self, config, **kwargs):
  1563. super().__init__(**kwargs)
  1564. if config.hidden_size % config.num_attention_heads != 0:
  1565. raise ValueError(
  1566. "The hidden size (%d) is not a multiple of the number of attention "
  1567. "heads (%d)" % (config.hidden_size, config.num_attention_heads)
  1568. )
  1569. self.output_attentions = config.output_attentions
  1570. self.num_attention_heads = config.num_attention_heads
  1571. assert config.hidden_size % config.num_attention_heads == 0
  1572. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  1573. self.all_head_size = self.num_attention_heads * self.attention_head_size
  1574. self.amp = config.amp
  1575. self.query = tf.keras.layers.Dense(
  1576. self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
  1577. )
  1578. self.key = tf.keras.layers.Dense(
  1579. self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key"
  1580. )
  1581. self.value = tf.keras.layers.Dense(
  1582. self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
  1583. )
  1584. self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob)
  1585. def transpose_for_scores(self, x, batch_size):
  1586. x = tf.reshape(x, (batch_size, -1, self.num_attention_heads, self.attention_head_size))
  1587. return tf.transpose(x, perm=[0, 2, 1, 3])
  1588. def call(self, inputs, training=False):
  1589. hidden_states, attention_mask, head_mask = inputs
  1590. batch_size = shape_list(hidden_states)[0]
  1591. mixed_query_layer = self.query(hidden_states)
  1592. mixed_key_layer = self.key(hidden_states)
  1593. mixed_value_layer = self.value(hidden_states)
  1594. query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
  1595. key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
  1596. value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)
  1597. # Take the dot product between "query" and "key" to get the raw attention scores.
  1598. attention_scores = tf.matmul(
  1599. query_layer, key_layer, transpose_b=True
  1600. ) # (batch size, num_heads, seq_len_q, seq_len_k)
  1601. dk = tf.cast(shape_list(key_layer)[-1], tf.float32)
  1602. attention_scores = attention_scores / tf.cast(tf.math.sqrt(dk), tf.float16 if self.amp else tf.float32)
  1603. if attention_mask is not None:
  1604. # Apply the attention mask is (precomputed for all layers in TFBertModel call() function)
  1605. attention_scores = attention_scores + attention_mask
  1606. # Normalize the attention scores to probabilities.
  1607. attention_probs = tf.nn.softmax(attention_scores, axis=-1)
  1608. # This is actually dropping out entire tokens to attend to, which might
  1609. # seem a bit unusual, but is taken from the original Transformer paper.
  1610. attention_probs = self.dropout(attention_probs, training=training)
  1611. # Mask heads if we want to
  1612. if head_mask is not None:
  1613. attention_probs = attention_probs * head_mask
  1614. context_layer = tf.matmul(attention_probs, value_layer)
  1615. context_layer = tf.transpose(context_layer, perm=[0, 2, 1, 3])
  1616. context_layer = tf.reshape(
  1617. context_layer, (batch_size, -1, self.all_head_size)
  1618. ) # (batch_size, seq_len_q, all_head_size)
  1619. outputs = (context_layer, attention_probs) if self.output_attentions else (context_layer,)
  1620. return outputs
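# Shape sketch (illustrative, e.g. hidden_size=768 and num_attention_heads=12; not part
# of the original module):
#   query/key/value after transpose_for_scores: [batch, 12, seq, 64]
#   attention_scores:                           [batch, 12, seq_q, seq_k], scaled by 1/sqrt(64)
#   context_layer returned to the caller:       [batch, seq_q, 768]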
  1621. class TFBertSelfOutput(tf.keras.layers.Layer):
  1622. def __init__(self, config, **kwargs):
  1623. super().__init__(**kwargs)
  1624. self.dense = tf.keras.layers.Dense(
  1625. config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
  1626. )
  1627. self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
  1628. self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
  1629. def call(self, inputs, training=False):
  1630. hidden_states, input_tensor = inputs
  1631. hidden_states = self.dense(hidden_states)
  1632. hidden_states = self.dropout(hidden_states, training=training)
  1633. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  1634. return hidden_states
  1635. class TFBertAttention(tf.keras.layers.Layer):
  1636. def __init__(self, config, **kwargs):
  1637. super().__init__(**kwargs)
  1638. self.self_attention = TFBertSelfAttention(config, name="self")
  1639. self.dense_output = TFBertSelfOutput(config, name="output")
  1640. def prune_heads(self, heads):
  1641. raise NotImplementedError
  1642. def call(self, inputs, training=False):
  1643. input_tensor, attention_mask, head_mask = inputs
  1644. self_outputs = self.self_attention([input_tensor, attention_mask, head_mask], training=training)
  1645. attention_output = self.dense_output([self_outputs[0], input_tensor], training=training)
  1646. outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
  1647. return outputs
  1648. class TFBertIntermediate(tf.keras.layers.Layer):
  1649. def __init__(self, config, **kwargs):
  1650. super().__init__(**kwargs)
  1651. self.dense = tf.keras.layers.Dense(
  1652. config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
  1653. )
  1654. if isinstance(config.hidden_act, str):
  1655. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  1656. else:
  1657. self.intermediate_act_fn = config.hidden_act
  1658. def call(self, hidden_states):
  1659. hidden_states = self.dense(hidden_states)
  1660. hidden_states = self.intermediate_act_fn(hidden_states)
  1661. return hidden_states
  1662. class TFBertOutput(tf.keras.layers.Layer):
  1663. def __init__(self, config, **kwargs):
  1664. super().__init__(**kwargs)
  1665. self.dense = tf.keras.layers.Dense(
  1666. config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
  1667. )
  1668. self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
  1669. self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
  1670. def call(self, inputs, training=False):
  1671. hidden_states, input_tensor = inputs
  1672. hidden_states = self.dense(hidden_states)
  1673. hidden_states = self.dropout(hidden_states, training=training)
  1674. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  1675. return hidden_states
  1676. class TFBertLayer(tf.keras.layers.Layer):
  1677. def __init__(self, config, **kwargs):
  1678. super().__init__(**kwargs)
  1679. self.attention = TFBertAttention(config, name="attention")
  1680. self.intermediate = TFBertIntermediate(config, name="intermediate")
  1681. self.bert_output = TFBertOutput(config, name="output")
  1682. def call(self, inputs, training=False):
  1683. hidden_states, attention_mask, head_mask = inputs
  1684. attention_outputs = self.attention([hidden_states, attention_mask, head_mask], training=training)
  1685. attention_output = attention_outputs[0]
  1686. intermediate_output = self.intermediate(attention_output)
  1687. layer_output = self.bert_output([intermediate_output, attention_output], training=training)
  1688. outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them
  1689. return outputs
  1690. class TFBertEncoder(tf.keras.layers.Layer):
  1691. def __init__(self, config, **kwargs):
  1692. super().__init__(**kwargs)
  1693. self.output_attentions = config.output_attentions
  1694. self.output_hidden_states = config.output_hidden_states
  1695. self.layer = [TFBertLayer(config, name="layer_._{}".format(i)) for i in range(config.num_hidden_layers)]
  1696. def call(self, inputs, training=False):
  1697. hidden_states, attention_mask, head_mask = inputs
  1698. all_hidden_states = ()
  1699. all_attentions = ()
  1700. for i, layer_module in enumerate(self.layer):
  1701. if self.output_hidden_states:
  1702. all_hidden_states = all_hidden_states + (hidden_states,)
  1703. layer_outputs = layer_module([hidden_states, attention_mask, head_mask[i]], training=training)
  1704. hidden_states = layer_outputs[0]
  1705. if self.output_attentions:
  1706. all_attentions = all_attentions + (layer_outputs[1],)
  1707. # Add last layer
  1708. if self.output_hidden_states:
  1709. all_hidden_states = all_hidden_states + (hidden_states,)
  1710. outputs = (hidden_states,)
  1711. if self.output_hidden_states:
  1712. outputs = outputs + (all_hidden_states,)
  1713. if self.output_attentions:
  1714. outputs = outputs + (all_attentions,)
  1715. return outputs # outputs, (hidden states), (attentions)
  1716. class TFBertPooler(tf.keras.layers.Layer):
  1717. def __init__(self, config, **kwargs):
  1718. super().__init__(**kwargs)
  1719. self.dense = tf.keras.layers.Dense(
  1720. config.hidden_size,
  1721. kernel_initializer=get_initializer(config.initializer_range),
  1722. activation="tanh",
  1723. name="dense",
  1724. )
  1725. def call(self, hidden_states):
  1726. # We "pool" the model by simply taking the hidden state corresponding
  1727. # to the first token.
  1728. first_token_tensor = hidden_states[:, 0]
  1729. pooled_output = self.dense(first_token_tensor)
  1730. return pooled_output
  1731. class TFBertPredictionHeadTransform(tf.keras.layers.Layer):
  1732. def __init__(self, config, **kwargs):
  1733. super().__init__(**kwargs)
  1734. self.dense = tf.keras.layers.Dense(
  1735. config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
  1736. )
  1737. if isinstance(config.hidden_act, str):
  1738. self.transform_act_fn = ACT2FN[config.hidden_act]
  1739. else:
  1740. self.transform_act_fn = config.hidden_act
  1741. self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
  1742. def call(self, hidden_states):
  1743. hidden_states = self.dense(hidden_states)
  1744. hidden_states = self.transform_act_fn(hidden_states)
  1745. hidden_states = self.LayerNorm(hidden_states)
  1746. return hidden_states
  1747. class TFBertLMPredictionHead(tf.keras.layers.Layer):
  1748. def __init__(self, config, input_embeddings, **kwargs):
  1749. super().__init__(**kwargs)
  1750. self.vocab_size = config.vocab_size
  1751. self.transform = TFBertPredictionHeadTransform(config, name="transform")
  1752. # The output weights are the same as the input embeddings, but there is
  1753. # an output-only bias for each token.
  1754. self.input_embeddings = input_embeddings
  1755. def build(self, input_shape):
  1756. self.bias = self.add_weight(shape=(self.vocab_size,), initializer="zeros", trainable=True, name="bias")
  1757. super().build(input_shape)
  1758. def call(self, hidden_states):
  1759. hidden_states = self.transform(hidden_states)
  1760. hidden_states = self.input_embeddings(hidden_states, mode="linear")
  1761. hidden_states = hidden_states + self.bias
  1762. return hidden_states
  1763. class TFBertMLMHead(tf.keras.layers.Layer):
  1764. def __init__(self, config, input_embeddings, **kwargs):
  1765. super().__init__(**kwargs)
  1766. self.predictions = TFBertLMPredictionHead(config, input_embeddings, name="predictions")
  1767. def call(self, sequence_output):
  1768. prediction_scores = self.predictions(sequence_output)
  1769. return prediction_scores
  1770. class TFBertNSPHead(tf.keras.layers.Layer):
  1771. def __init__(self, config, **kwargs):
  1772. super().__init__(**kwargs)
  1773. self.seq_relationship = tf.keras.layers.Dense(
  1774. 2, kernel_initializer=get_initializer(config.initializer_range), name="seq_relationship"
  1775. )
  1776. def call(self, pooled_output):
  1777. seq_relationship_score = self.seq_relationship(pooled_output)
  1778. return seq_relationship_score
  1779. @keras_serializable
  1780. class TFBertMainLayer(tf.keras.layers.Layer):
  1781. config_class = BertConfig
  1782. def __init__(self, config, **kwargs):
  1783. super().__init__(**kwargs)
  1784. self.num_hidden_layers = config.num_hidden_layers
  1785. self.embeddings = TFBertEmbeddings(config, name="embeddings")
  1786. self.encoder = TFBertEncoder(config, name="encoder")
  1787. self.pooler = TFBertPooler(config, name="pooler")
  1788. def get_input_embeddings(self):
  1789. return self.embeddings
  1790. def _resize_token_embeddings(self, new_num_tokens):
  1791. raise NotImplementedError
  1792. def _prune_heads(self, heads_to_prune):
  1793. """ Prunes heads of the model.
  1794. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
  1795. See base class PreTrainedModel
  1796. """
  1797. raise NotImplementedError
  1798. def call(
  1799. self,
  1800. inputs,
  1801. attention_mask=None,
  1802. token_type_ids=None,
  1803. position_ids=None,
  1804. head_mask=None,
  1805. inputs_embeds=None,
  1806. training=False,
  1807. ):
  1808. if isinstance(inputs, (tuple, list)):
  1809. input_ids = inputs[0]
  1810. attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
  1811. token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
  1812. position_ids = inputs[3] if len(inputs) > 3 else position_ids
  1813. head_mask = inputs[4] if len(inputs) > 4 else head_mask
  1814. inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
  1815. assert len(inputs) <= 6, "Too many inputs."
  1816. elif isinstance(inputs, (dict, BatchEncoding)):
  1817. input_ids = inputs.get("input_ids")
  1818. attention_mask = inputs.get("attention_mask", attention_mask)
  1819. token_type_ids = inputs.get("token_type_ids", token_type_ids)
  1820. position_ids = inputs.get("position_ids", position_ids)
  1821. head_mask = inputs.get("head_mask", head_mask)
  1822. inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
  1823. assert len(inputs) <= 6, "Too many inputs."
  1824. else:
  1825. input_ids = inputs
  1826. if input_ids is not None and inputs_embeds is not None:
  1827. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  1828. elif input_ids is not None:
  1829. input_shape = shape_list(input_ids)
  1830. elif inputs_embeds is not None:
  1831. input_shape = shape_list(inputs_embeds)[:-1]
  1832. else:
  1833. raise ValueError("You have to specify either input_ids or inputs_embeds")
  1834. if attention_mask is None:
  1835. attention_mask = tf.fill(input_shape, 1)
  1836. if token_type_ids is None:
  1837. token_type_ids = tf.fill(input_shape, 0)
  1838. # We create a 3D attention mask from a 2D tensor mask.
  1839. # Sizes are [batch_size, 1, 1, to_seq_length]
  1840. # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
1841. # this attention mask is simpler than the triangular masking of causal attention
  1842. # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
  1843. extended_attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]
  1844. # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
  1845. # masked positions, this operation will create a tensor which is 0.0 for
  1846. # positions we want to attend and -10000.0 for masked positions.
  1847. # Since we are adding it to the raw scores before the softmax, this is
  1848. # effectively the same as removing these entirely.
  1849. extended_attention_mask = tf.cast(extended_attention_mask, tf.float32)
  1850. extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
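# Illustrative sketch (not part of the original module): a padding mask of [1, 1, 0]
# becomes the additive mask [0.0, 0.0, -10000.0] broadcastable to
# [batch, num_heads, from_seq, to_seq], so padded positions receive a large negative
# score before the softmax and end up with ~zero attention weight.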
        # Prepare head mask if needed
        # 1.0 in head_mask indicates we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        if head_mask is not None:
            raise NotImplementedError
        else:
            head_mask = [None] * self.num_hidden_layers
            # head_mask = tf.constant([0] * self.num_hidden_layers)

        embedding_output = self.embeddings([input_ids, position_ids, token_type_ids, inputs_embeds], training=training)
        encoder_outputs = self.encoder([embedding_output, extended_attention_mask, head_mask], training=training)

        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output)

        outputs = (sequence_output, pooled_output,) + encoder_outputs[
            1:
        ]  # add hidden_states and attentions if they are here
        return outputs  # sequence_output, pooled_output, (hidden_states), (attentions)
class TFBertPreTrainedModel(TFPreTrainedModel):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
    """

    config_class = BertConfig
    pretrained_model_archive_map = TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP
    base_model_prefix = "bert"


BERT_START_DOCSTRING = r"""
    This model is a `tf.keras.Model <https://www.tensorflow.org/api_docs/python/tf/keras/Model>`__ sub-class.
    Use it as a regular TF 2.0 Keras Model and
    refer to the TF 2.0 documentation for all matters related to general usage and behavior.

    .. note::

        TF 2.0 models accept two formats as inputs:

        - having all inputs as keyword arguments (like PyTorch models), or
        - having all inputs as a list, tuple or dict in the first positional argument.

        This second option is useful when using the :obj:`tf.keras.Model.fit()` method, which currently requires having
        all the tensors in the first argument of the model call function: :obj:`model(inputs)`.

        If you choose this second option, there are three possibilities you can use to gather all the input Tensors
        in the first positional argument:

        - a single Tensor with input_ids only and nothing else: :obj:`model(input_ids)`
        - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
          :obj:`model([input_ids, attention_mask])` or :obj:`model([input_ids, attention_mask, token_type_ids])`
        - a dictionary with one or several input Tensors associated with the input names given in the docstring:
          :obj:`model({'input_ids': input_ids, 'token_type_ids': token_type_ids})`

    Parameters:
        config (:class:`~transformers.BertConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.
            Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""
BERT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.
            Indices can be obtained using :class:`transformers.BertTokenizer`.
            See :func:`transformers.PreTrainedTokenizer.encode` and
            :func:`transformers.PreTrainedTokenizer.encode_plus` for details.
            `What are input IDs? <../glossary.html#input-ids>`__
        attention_mask (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Mask to avoid performing attention on padding token indices.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
            `What are attention masks? <../glossary.html#attention-mask>`__
        token_type_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Segment token indices to indicate first and second portions of the inputs.
            Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
            corresponds to a `sentence B` token.
            `What are token type IDs? <../glossary.html#token-type-ids>`__
        position_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Indices of positions of each input sequence token in the position embeddings.
            Selected in the range ``[0, config.max_position_embeddings - 1]``.
            `What are position IDs? <../glossary.html#position-ids>`__
        head_mask (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`):
            Mask to nullify selected heads of the self-attention modules.
            Mask values selected in ``[0, 1]``:
            :obj:`1` indicates the head is **not masked**, :obj:`0` indicates the head is **masked**.
        inputs_embeds (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, embedding_dim)`, `optional`, defaults to :obj:`None`):
            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
            than the model's internal embedding lookup matrix.
        training (:obj:`boolean`, `optional`, defaults to :obj:`False`):
            Whether to activate the dropout modules (if set to :obj:`True`) during training or to de-activate them
            (if set to :obj:`False`) for evaluation.
"""
@add_start_docstrings(
    "The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
    BERT_START_DOCSTRING,
)
class TFBertModel(TFBertPreTrainedModel):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.bert = TFBertMainLayer(config, name="bert")

    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
    def call(self, inputs, **kwargs):
        r"""
        Returns:
            :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
            last_hidden_state (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
                Sequence of hidden-states at the output of the last layer of the model.
            pooler_output (:obj:`tf.Tensor` of shape :obj:`(batch_size, hidden_size)`):
                Last layer hidden-state of the first token of the sequence (classification token)
                further processed by a Linear layer and a Tanh activation function. The Linear
                layer weights are trained from the next sentence prediction (classification)
                objective during Bert pretraining. This output is usually *not* a good summary
                of the semantic content of the input; you're often better off averaging or pooling
                the sequence of hidden-states for the whole input sequence.
            hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
                tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
                of shape :obj:`(batch_size, sequence_length, hidden_size)`.
                Hidden-states of the model at the output of each layer plus the initial embedding outputs.
            attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
                tuple of :obj:`tf.Tensor` (one for each layer) of shape
                :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
                Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

        Examples::

            import tensorflow as tf
            from transformers import BertTokenizer, TFBertModel

            tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
            model = TFBertModel.from_pretrained('bert-base-uncased')
            input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :]  # Batch size 1
            outputs = model(input_ids)
            last_hidden_states = outputs[0]  # The last hidden-state is the first element of the output tuple
        """
        outputs = self.bert(inputs, **kwargs)
        return outputs
@add_start_docstrings(
    """Bert Model with two heads on top as done during the pre-training:
    a `masked language modeling` head and a `next sentence prediction (classification)` head. """,
    BERT_START_DOCSTRING,
)
class TFBertForPreTraining(TFBertPreTrainedModel):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        self.bert = TFBertMainLayer(config, name="bert")
        self.nsp = TFBertNSPHead(config, name="nsp___cls")
        self.mlm = TFBertMLMHead(config, self.bert.embeddings, name="mlm___cls")

    def get_output_embeddings(self):
        return self.bert.embeddings

    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
    def call(self, inputs, **kwargs):
        r"""
        Return:
            :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
            prediction_scores (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
                Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
            seq_relationship_scores (:obj:`tf.Tensor` of shape :obj:`(batch_size, 2)`):
                Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation before SoftMax).
            hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
                tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
                of shape :obj:`(batch_size, sequence_length, hidden_size)`.
                Hidden-states of the model at the output of each layer plus the initial embedding outputs.
            attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
                tuple of :obj:`tf.Tensor` (one for each layer) of shape
                :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
                Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

        Examples::

            import tensorflow as tf
            from transformers import BertTokenizer, TFBertForPreTraining

            tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
            model = TFBertForPreTraining.from_pretrained('bert-base-uncased')
            input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :]  # Batch size 1
            outputs = model(input_ids)
            prediction_scores, seq_relationship_scores = outputs[:2]
        """
        outputs = self.bert(inputs, **kwargs)

        sequence_output, pooled_output = outputs[:2]
        prediction_scores = self.mlm(sequence_output, training=kwargs.get("training", False))
        seq_relationship_score = self.nsp(pooled_output)

        outputs = (prediction_scores, seq_relationship_score,) + outputs[
            2:
        ]  # add hidden states and attention if they are here

        return outputs  # prediction_scores, seq_relationship_score, (hidden_states), (attentions)
@add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING)
class TFBertForMaskedLM(TFBertPreTrainedModel):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        self.bert = TFBertMainLayer(config, name="bert")
        self.mlm = TFBertMLMHead(config, self.bert.embeddings, name="mlm___cls")

    def get_output_embeddings(self):
        return self.bert.embeddings

    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
    def call(self, inputs, **kwargs):
        r"""
        Return:
            :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
            prediction_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
                Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
            hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
                tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
                of shape :obj:`(batch_size, sequence_length, hidden_size)`.
                Hidden-states of the model at the output of each layer plus the initial embedding outputs.
            attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
                tuple of :obj:`tf.Tensor` (one for each layer) of shape
                :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
                Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

        Examples::

            import tensorflow as tf
            from transformers import BertTokenizer, TFBertForMaskedLM

            tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
            model = TFBertForMaskedLM.from_pretrained('bert-base-uncased')
            input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :]  # Batch size 1
            outputs = model(input_ids)
            prediction_scores = outputs[0]
        """
        outputs = self.bert(inputs, **kwargs)

        sequence_output = outputs[0]
        prediction_scores = self.mlm(sequence_output, training=kwargs.get("training", False))

        outputs = (prediction_scores,) + outputs[2:]  # Add hidden states and attention if they are here

        return outputs  # prediction_scores, (hidden_states), (attentions)
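# Illustrative usage sketch (not part of the library), assuming `prediction_scores` comes from the
# TFBertForMaskedLM example above and `masked_index` is a hypothetical position of a [MASK] token:
#
#     predicted_id = tf.argmax(prediction_scores[0, masked_index]).numpy()
#     predicted_token = tokenizer.convert_ids_to_tokens([predicted_id])[0]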
@add_start_docstrings(
    """Bert Model with a `next sentence prediction (classification)` head on top. """, BERT_START_DOCSTRING,
)
class TFBertForNextSentencePrediction(TFBertPreTrainedModel):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        self.bert = TFBertMainLayer(config, name="bert")
        self.nsp = TFBertNSPHead(config, name="nsp___cls")

    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
    def call(self, inputs, **kwargs):
        r"""
        Return:
            :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
            seq_relationship_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, 2)`):
                Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation before SoftMax).
            hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
                tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
                of shape :obj:`(batch_size, sequence_length, hidden_size)`.
                Hidden-states of the model at the output of each layer plus the initial embedding outputs.
            attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
                tuple of :obj:`tf.Tensor` (one for each layer) of shape
                :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
                Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

        Examples::

            import tensorflow as tf
            from transformers import BertTokenizer, TFBertForNextSentencePrediction

            tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
            model = TFBertForNextSentencePrediction.from_pretrained('bert-base-uncased')
            input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :]  # Batch size 1
            outputs = model(input_ids)
            seq_relationship_scores = outputs[0]
        """
        outputs = self.bert(inputs, **kwargs)

        pooled_output = outputs[1]
        seq_relationship_score = self.nsp(pooled_output)

        outputs = (seq_relationship_score,) + outputs[2:]  # add hidden states and attention if they are here

        return outputs  # seq_relationship_score, (hidden_states), (attentions)
@add_start_docstrings(
    """Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of
    the pooled output) e.g. for GLUE tasks. """,
    BERT_START_DOCSTRING,
)
class TFBertForSequenceClassification(TFBertPreTrainedModel):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.num_labels = config.num_labels

        self.bert = TFBertMainLayer(config, name="bert")
        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
        self.classifier = tf.keras.layers.Dense(
            config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
        )

    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
    def call(self, inputs, **kwargs):
        r"""
        Return:
            :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
            logits (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, config.num_labels)`):
                Classification (or regression if config.num_labels==1) scores (before SoftMax).
            hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
                tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
                of shape :obj:`(batch_size, sequence_length, hidden_size)`.
                Hidden-states of the model at the output of each layer plus the initial embedding outputs.
            attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
                tuple of :obj:`tf.Tensor` (one for each layer) of shape
                :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
                Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

        Examples::

            import tensorflow as tf
            from transformers import BertTokenizer, TFBertForSequenceClassification

            tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
            model = TFBertForSequenceClassification.from_pretrained('bert-base-uncased')
            input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :]  # Batch size 1
            outputs = model(input_ids)
            logits = outputs[0]
        """
        outputs = self.bert(inputs, **kwargs)

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output, training=kwargs.get("training", False))
        logits = self.classifier(pooled_output)

        outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here

        return outputs  # logits, (hidden_states), (attentions)
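# Illustrative usage sketch (not part of the library): turning the `logits` from the
# TFBertForSequenceClassification example above into class probabilities and a predicted label id:
#
#     probabilities = tf.nn.softmax(logits, axis=-1)          # shape (batch_size, num_labels)
#     predicted_class = tf.argmax(logits, axis=-1).numpy()    # shape (batch_size,)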
@add_start_docstrings(
    """Bert Model with a multiple choice classification head on top (a linear layer on top of
    the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
    BERT_START_DOCSTRING,
)
class TFBertForMultipleChoice(TFBertPreTrainedModel):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        self.bert = TFBertMainLayer(config, name="bert")
        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
        self.classifier = tf.keras.layers.Dense(
            1, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
        )

    @property
    def dummy_inputs(self):
        """ Dummy inputs to build the network.

        Returns:
            tf.Tensor with dummy inputs
        """
        return {"input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS)}

    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
    def call(
        self,
        inputs,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        training=False,
    ):
        r"""
        Return:
            :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
            classification_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, num_choices)`):
                `num_choices` is the size of the second dimension of the input tensors (see `input_ids` above).
                Classification scores (before SoftMax).
            hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
                tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
                of shape :obj:`(batch_size, sequence_length, hidden_size)`.
                Hidden-states of the model at the output of each layer plus the initial embedding outputs.
            attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
                tuple of :obj:`tf.Tensor` (one for each layer) of shape
                :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
                Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

        Examples::

            import tensorflow as tf
            from transformers import BertTokenizer, TFBertForMultipleChoice

            tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
            model = TFBertForMultipleChoice.from_pretrained('bert-base-uncased')
            choices = ["Hello, my dog is cute", "Hello, my cat is amazing"]
            input_ids = tf.constant([tokenizer.encode(s, add_special_tokens=True) for s in choices])[None, :]  # Batch size 1, 2 choices
            outputs = model(input_ids)
            classification_scores = outputs[0]
        """
        if isinstance(inputs, (tuple, list)):
            input_ids = inputs[0]
            attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
            token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
            position_ids = inputs[3] if len(inputs) > 3 else position_ids
            head_mask = inputs[4] if len(inputs) > 4 else head_mask
            inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
            assert len(inputs) <= 6, "Too many inputs."
        elif isinstance(inputs, dict):
            input_ids = inputs.get("input_ids")
            attention_mask = inputs.get("attention_mask", attention_mask)
            token_type_ids = inputs.get("token_type_ids", token_type_ids)
            position_ids = inputs.get("position_ids", position_ids)
            head_mask = inputs.get("head_mask", head_mask)
            inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
            assert len(inputs) <= 6, "Too many inputs."
        else:
            input_ids = inputs

        if input_ids is not None:
            num_choices = shape_list(input_ids)[1]
            seq_length = shape_list(input_ids)[2]
        else:
            num_choices = shape_list(inputs_embeds)[1]
            seq_length = shape_list(inputs_embeds)[2]
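        # Flatten the choices dimension so the shared BERT encoder sees inputs of shape
        # (batch_size * num_choices, seq_length).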
        flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
        flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
        flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
        flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None

        flat_inputs = [
            flat_input_ids,
            flat_attention_mask,
            flat_token_type_ids,
            flat_position_ids,
            head_mask,
            inputs_embeds,
        ]

        outputs = self.bert(flat_inputs, training=training)

        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output, training=training)
        logits = self.classifier(pooled_output)
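        # The classifier produces one score per flattened choice, shape (batch_size * num_choices, 1);
        # reshape it back to (batch_size, num_choices) so the choices of each example can be compared.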
        reshaped_logits = tf.reshape(logits, (-1, num_choices))

        outputs = (reshaped_logits,) + outputs[2:]  # add hidden states and attention if they are here

        return outputs  # reshaped_logits, (hidden_states), (attentions)
@add_start_docstrings(
    """Bert Model with a token classification head on top (a linear layer on top of
    the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
    BERT_START_DOCSTRING,
)
class TFBertForTokenClassification(TFBertPreTrainedModel):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.num_labels = config.num_labels

        self.bert = TFBertMainLayer(config, name="bert")
        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
        self.classifier = tf.keras.layers.Dense(
            config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
        )

    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
    def call(self, inputs, **kwargs):
        r"""
        Return:
            :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
            scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`):
                Classification scores (before SoftMax).
            hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
                tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
                of shape :obj:`(batch_size, sequence_length, hidden_size)`.
                Hidden-states of the model at the output of each layer plus the initial embedding outputs.
            attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
                tuple of :obj:`tf.Tensor` (one for each layer) of shape
                :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
                Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

        Examples::

            import tensorflow as tf
            from transformers import BertTokenizer, TFBertForTokenClassification

            tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
            model = TFBertForTokenClassification.from_pretrained('bert-base-uncased')
            input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :]  # Batch size 1
            outputs = model(input_ids)
            scores = outputs[0]
        """
        outputs = self.bert(inputs, **kwargs)

        sequence_output = outputs[0]

        sequence_output = self.dropout(sequence_output, training=kwargs.get("training", False))
        logits = self.classifier(sequence_output)

        outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here

        return outputs  # scores, (hidden_states), (attentions)
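# Illustrative usage sketch (not part of the library): converting the per-token `scores` from the
# TFBertForTokenClassification example above into one predicted label id per token:
#
#     predicted_label_ids = tf.argmax(scores, axis=-1).numpy()  # shape (batch_size, sequence_length)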
@add_start_docstrings(
    """Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layer on top of
    the hidden-states output to compute `span start logits` and `span end logits`). """,
    BERT_START_DOCSTRING,
)
class TFBertForQuestionAnswering(TFBertPreTrainedModel):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.num_labels = config.num_labels

        self.bert = TFBertMainLayer(config, name="bert")
        self.qa_outputs = tf.keras.layers.Dense(
            config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
        )

    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
    def call(self, inputs, **kwargs):
        r"""
        Return:
            :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
            start_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`):
                Span-start scores (before SoftMax).
            end_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`):
                Span-end scores (before SoftMax).
            hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
                tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
                of shape :obj:`(batch_size, sequence_length, hidden_size)`.
                Hidden-states of the model at the output of each layer plus the initial embedding outputs.
            attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
                tuple of :obj:`tf.Tensor` (one for each layer) of shape
                :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
                Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.

        Examples::

            import tensorflow as tf
            from transformers import BertTokenizer, TFBertForQuestionAnswering

            tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
            model = TFBertForQuestionAnswering.from_pretrained('bert-base-uncased')
            input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :]  # Batch size 1
            outputs = model(input_ids)
            start_scores, end_scores = outputs[:2]
        """
        outputs = self.bert(inputs, **kwargs)

        sequence_output = outputs[0]

        logits = self.qa_outputs(sequence_output)
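        # qa_outputs produces two values per token, shape (batch_size, sequence_length, 2);
        # split them into separate start and end logits of shape (batch_size, sequence_length).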
        start_logits, end_logits = tf.split(logits, 2, axis=-1)
        start_logits = tf.squeeze(start_logits, axis=-1)
        end_logits = tf.squeeze(end_logits, axis=-1)

        outputs = (start_logits, end_logits,) + outputs[2:]

        return outputs  # start_logits, end_logits, (hidden_states), (attentions)
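# Illustrative usage sketch (not part of the library): extracting an answer span from the
# TFBertForQuestionAnswering example above; `all_tokens` is a hypothetical list of the tokens
# corresponding to `input_ids` (e.g. from tokenizer.convert_ids_to_tokens):
#
#     start = tf.argmax(start_scores, axis=-1).numpy()[0]
#     end = tf.argmax(end_scores, axis=-1).numpy()[0]
#     answer_tokens = all_tokens[start : end + 1]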