- # coding=utf-8
- # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
- # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """TF general model utils."""
- import functools
- import logging
- import os
- import h5py
- import numpy as np
- import tensorflow as tf
- from tensorflow.python.keras.saving import hdf5_format
- from configuration_utils import PretrainedConfig, BertConfig
- from file_utils import DUMMY_INPUTS, TF2_WEIGHTS_NAME, WEIGHTS_NAME, cached_path, hf_bucket_url, is_remote_url
- from file_utils import MULTIPLE_CHOICE_DUMMY_INPUTS, add_start_docstrings, add_start_docstrings_to_callable
- from tokenization_utils import BatchEncoding
- from utils import log
- class TFModelUtilsMixin:
- """
- A few utilities for `tf.keras.Model`s, to be used as a mixin.
- """
- def num_parameters(self, only_trainable: bool = False) -> int:
- """
- Get number of (optionally, trainable) parameters in the model.
- """
- if only_trainable:
- return int(sum(np.prod(w.shape.as_list()) for w in self.trainable_variables))
- else:
- return self.count_params()
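- # Illustrative usage (a sketch; assumes `model` is a built instance of a TFPreTrainedModel subclass):
- # total_params = model.num_parameters()
- # trainable_params = model.num_parameters(only_trainable=True)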
- def keras_serializable(cls):
- """
- Decorate a Keras Layer class to support Keras serialization.
- This is done by:
- 1. adding a `transformers_config` dict to the Keras config dictionary in `get_config` (called by Keras at
- serialization time)
- 2. wrapping `__init__` to accept that `transformers_config` dict (passed by Keras at deserialization time) and
- convert it to a config object for the actual layer initializer
- 3. registering the class as a custom object in Keras (if the TensorFlow version supports this), so that it does
- not need to be supplied in `custom_objects` in the call to `tf.keras.models.load_model`
- :param cls: a tf.keras.layers.Layer subclass that accepts a `config` argument to its initializer (typically a
- `TF*MainLayer` class in this project)
- :return: the same class object, with modifications for Keras deserialization.
- """
- initializer = cls.__init__
- config_class = getattr(cls, "config_class", None)
- if config_class is None:
- raise AttributeError("Must set `config_class` to use @keras_serializable")
- @functools.wraps(initializer)
- def wrapped_init(self, *args, **kwargs):
- transformers_config = kwargs.pop("transformers_config", None)
- config = args[0] if args and isinstance(args[0], PretrainedConfig) else kwargs.get("config", None)
- if config is not None and transformers_config is not None:
- raise ValueError("Must pass either `config` or `transformers_config`, not both")
- elif config is not None:
- # normal layer construction, call with unchanged args (config is already in there)
- initializer(self, *args, **kwargs)
- elif transformers_config is not None:
- # Keras deserialization, convert dict to config
- config = config_class.from_dict(transformers_config)
- initializer(self, config, *args, **kwargs)
- else:
- raise ValueError("Must pass either `config` (PretrainedConfig) or `transformers_config` (dict)")
- self._transformers_config = config
- cls.__init__ = wrapped_init
- if not hasattr(cls, "get_config"):
- raise TypeError("Only use @keras_serializable on tf.keras.layers.Layer subclasses")
- if hasattr(cls.get_config, "_is_default"):
- def get_config(self):
- cfg = super(cls, self).get_config()
- cfg["transformers_config"] = self._transformers_config.to_dict()
- return cfg
- cls.get_config = get_config
- cls._keras_serializable = True
- if hasattr(tf.keras.utils, "register_keras_serializable"):
- cls = tf.keras.utils.register_keras_serializable()(cls)
- return cls
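- # Illustrative sketch of the decorator in use. `MyConfig` and `TFMyMainLayer` are hypothetical
- # names, not part of this module; any `tf.keras.layers.Layer` subclass with a `config_class`
- # attribute and a `config`-first initializer works the same way:
- #
- # @keras_serializable
- # class TFMyMainLayer(tf.keras.layers.Layer):
- #     config_class = MyConfig  # required by @keras_serializable
- #
- #     def __init__(self, config, **kwargs):
- #         super().__init__(**kwargs)
- #         self.hidden_size = config.hidden_size
- #
- # A Keras model containing such a layer can then round-trip through `model.save(...)` /
- # `tf.keras.models.load_model(...)`, because the layer's Keras config carries the
- # `transformers_config` dict and the class is registered as a custom object.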
- class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
- r""" Base class for all TF models.
- :class:`~transformers.TFPreTrainedModel` takes care of storing the configuration of the models and handles methods for loading/downloading/saving models
- as well as a few methods common to all models to (i) resize the input embeddings and (ii) prune heads in the self-attention modules.
- Class attributes (overridden by derived classes):
- - ``config_class``: a class derived from :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
- - ``pretrained_model_archive_map``: a python ``dict`` with `short-cut-names` (string) as keys and `url` (string) of associated pretrained weights as values.
- - ``load_tf_weights``: a python ``method`` for loading a TensorFlow checkpoint into the model, taking as arguments:
- - ``model``: an instance of the relevant subclass of :class:`~transformers.TFPreTrainedModel`,
- - ``config``: an instance of the relevant subclass of :class:`~transformers.PretrainedConfig`,
- - ``path``: a path (string) to the TensorFlow checkpoint.
- - ``base_model_prefix``: a string indicating the attribute associated with the base model in derived classes of the same architecture adding modules on top of the base model.
- """
- config_class = None
- pretrained_model_archive_map = {}
- base_model_prefix = ""
- @property
- def dummy_inputs(self):
- """ Dummy inputs to build the network.
- Returns:
- tf.Tensor with dummy inputs
- """
- return {"input_ids": tf.constant(DUMMY_INPUTS)}
- def __init__(self, config, *inputs, **kwargs):
- super().__init__(*inputs, **kwargs)
- if not isinstance(config, PretrainedConfig):
- raise ValueError(
- "Parameter config in `{}(config)` should be an instance of class `PretrainedConfig`. "
- "To create a model from a pretrained model use "
- "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
- self.__class__.__name__, self.__class__.__name__
- )
- )
- # Save config in model
- self.config = config
- def get_input_embeddings(self):
- """
- Returns the model's input embeddings.
- Returns:
- :obj:`tf.keras.layers.Layer`:
- A Keras layer mapping vocabulary to hidden states.
- """
- base_model = getattr(self, self.base_model_prefix, self)
- if base_model is not self:
- return base_model.get_input_embeddings()
- else:
- raise NotImplementedError
- def get_output_embeddings(self):
- """
- Returns the model's output embeddings.
- Returns:
- :obj:`tf.keras.layers.Layer`:
- A Keras layer mapping hidden states to vocabulary.
- """
- return None # Overwrite for models with output embeddings
- def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None):
- """ Build a resized Embedding Variable from a provided token Embedding Module.
- Increasing the size will add newly initialized vectors at the end
- Reducing the size will remove vectors from the end
- Args:
- new_num_tokens: (`optional`) int
- New number of tokens in the embedding matrix.
- Increasing the size will add newly initialized vectors at the end
- Reducing the size will remove vectors from the end
- If not provided or None: return the provided token Embedding Module.
- Return: ``tf.Variable``
- Pointer to the resized Embedding Module or the old Embedding Module if new_num_tokens is None
- """
- if new_num_tokens is None:
- return old_embeddings
- raise NotImplementedError  # resizing is not implemented for TF yet; the PyTorch reference logic is kept below
- # old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
- # if old_num_tokens == new_num_tokens:
- # return old_embeddings
- # # Build new embeddings
- # new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
- # new_embeddings.to(old_embeddings.weight.device)
- # # initialize all new embeddings (in particular added tokens)
- # self._init_weights(new_embeddings)
- # # Copy token embeddings from the previous weights
- # num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
- # new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[:num_tokens_to_copy, :]
- # return new_embeddings
- def resize_token_embeddings(self, new_num_tokens=None):
- """ Resize input token embeddings matrix of the model if new_num_tokens != config.vocab_size.
- Take care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.
- Arguments:
- new_num_tokens: (`optional`) int:
- New number of tokens in the embedding matrix. Increasing the size will add newly initialized vectors at the end. Reducing the size will remove vectors from the end.
- If not provided or None: does nothing and just returns a pointer to the input tokens ``tf.Variable`` Module of the model.
- Return: ``tf.Variable``
- Pointer to the input tokens Embeddings Module of the model
- """
- raise NotImplementedError
- def prune_heads(self, heads_to_prune):
- """ Prunes heads of the base model.
- Arguments:
- heads_to_prune: dict with keys being selected layer indices (`int`) and associated values being the list of heads to prune in said layer (list of `int`).
- """
- raise NotImplementedError
- def save_pretrained(self, save_directory):
- """ Save a model and its configuration file to a directory, so that it
- can be re-loaded using the :func:`~transformers.PreTrainedModel.from_pretrained` class method.
- """
- if os.path.isfile(save_directory):
- log("Provided path ({}) should be a directory, not a file".format(save_directory))
- return
- os.makedirs(save_directory, exist_ok=True)
- # Save configuration file
- self.config.save_pretrained(save_directory)
- # If we save using the predefined names, we can load using `from_pretrained`
- output_model_file = os.path.join(save_directory, TF2_WEIGHTS_NAME)
- self.save_weights(output_model_file)
- with h5py.File(output_model_file, "r") as f:
- if "layer_names" not in f.attrs and "model_weights" in f:
- f = f["model_weights"]
- hdf5_layer_names = set(hdf5_format.load_attributes_from_hdf5_group(f, "layer_names"))
- log(f"Model weights saved in {output_model_file}: {hdf5_layer_names}")
- @classmethod
- def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
- r"""Instantiate a pretrained TF 2.0 model from a pre-trained model configuration.
- The warning ``Weights from XXX not initialized from pretrained model`` means that the weights of XXX do not come pre-trained with the rest of the model.
- It is up to you to train those weights with a downstream fine-tuning task.
- The warning ``Weights from XXX not used in YYY`` means that the layer XXX is not used by YYY, therefore those weights are discarded.
- Parameters:
- pretrained_model_name_or_path: either:
- - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
- - a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
- - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
- - a path or url to a `PyTorch state_dict save file` (e.g. `./pt_model/pytorch_model.bin`). In this case, ``from_pt`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the PyTorch checkpoint to a TensorFlow model using the provided conversion scripts and loading the TensorFlow model afterwards.
- model_args: (`optional`) Sequence of positional arguments:
- All remaining positional arguments will be passed to the underlying model's ``__init__`` method
- config: (`optional`) one of:
- - an instance of a class derived from :class:`~transformers.PretrainedConfig`, or
- - a string valid as input to :func:`~transformers.PretrainedConfig.from_pretrained()`
- Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
- - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
- - the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by supplying the save directory.
- - the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.
- from_pt: (`optional`) boolean, default False:
- Load the model weights from a PyTorch state_dict save file (see docstring of pretrained_model_name_or_path argument).
- cache_dir: (`optional`) string:
- Path to a directory in which a downloaded pre-trained model
- configuration should be cached if the standard cache should not be used.
- force_download: (`optional`) boolean, default False:
- Force to (re-)download the model weights and configuration files and override the cached versions if they exist.
- resume_download: (`optional`) boolean, default False:
- Do not delete an incompletely received file. Attempt to resume the download if such a file exists.
- proxies: (`optional`) dict, default None:
- A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
- The proxies are used on each request.
- output_loading_info: (`optional`) boolean:
- Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.
- kwargs: (`optional`) Remaining dictionary of keyword arguments:
- Can be used to update the configuration object (after it has been loaded) and to initialize the model (e.g. ``output_attention=True``). Behaves differently depending on whether a `config` is provided or automatically loaded:
- - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the underlying model's ``__init__`` method (we assume all relevant updates to the configuration have already been done)
- - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's ``__init__`` function.
- Examples::
- # For example purposes. Not runnable.
- model = BertModel.from_pretrained('bert-base-uncased') # Download model and configuration from S3 and cache.
- model = BertModel.from_pretrained('./test/saved_model/') # E.g. model was saved using `save_pretrained('./test/saved_model/')`
- model = BertModel.from_pretrained('bert-base-uncased', output_attention=True) # Update configuration during loading
- assert model.config.output_attention == True
- # Loading from a PyTorch checkpoint file instead of a TF 2.0 model (slower)
- config = BertConfig.from_json_file('./pt_model/my_pt_model_config.json')
- model = BertModel.from_pretrained('./pt_model/my_pytorch_model.bin', from_pt=True, config=config)
- """
- config = kwargs.pop("config", None)
- cache_dir = kwargs.pop("cache_dir", None)
- from_pt = kwargs.pop("from_pt", False)
- force_download = kwargs.pop("force_download", False)
- resume_download = kwargs.pop("resume_download", False)
- proxies = kwargs.pop("proxies", None)
- output_loading_info = kwargs.pop("output_loading_info", False)
- # Load config if we don't provide a configuration
- if not isinstance(config, PretrainedConfig):
- config_path = config if config is not None else pretrained_model_name_or_path
- config, model_kwargs = cls.config_class.from_pretrained(
- config_path,
- *model_args,
- cache_dir=cache_dir,
- return_unused_kwargs=True,
- force_download=force_download,
- resume_download=resume_download,
- **kwargs,
- )
- else:
- model_kwargs = kwargs
- # Load model
- if pretrained_model_name_or_path is not None:
- if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
- archive_file = cls.pretrained_model_archive_map[pretrained_model_name_or_path]
- elif os.path.isdir(pretrained_model_name_or_path):
- if os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
- # Load from a TF 2.0 checkpoint
- archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
- elif from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
- # Load from a PyTorch checkpoint
- archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
- else:
- raise EnvironmentError(
- "Error no file named {} found in directory {} or `from_pt` set to False".format(
- [WEIGHTS_NAME, TF2_WEIGHTS_NAME], pretrained_model_name_or_path
- )
- )
- elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
- archive_file = pretrained_model_name_or_path
- elif os.path.isfile(pretrained_model_name_or_path + ".index"):
- archive_file = pretrained_model_name_or_path + ".index"
- else:
- archive_file = hf_bucket_url(
- pretrained_model_name_or_path, postfix=(WEIGHTS_NAME if from_pt else TF2_WEIGHTS_NAME)
- )
- # redirect to the cache, if necessary
- try:
- resolved_archive_file = cached_path(
- archive_file,
- cache_dir=cache_dir,
- force_download=force_download,
- resume_download=resume_download,
- proxies=proxies,
- )
- except EnvironmentError as e:
- if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
- log("Couldn't reach server at '{}' to download pretrained weights.".format(archive_file))
- else:
- log(
- "Model name '{}' was not found in model name list ({}). "
- "We assumed '{}' was a path or url but couldn't find any file "
- "associated to this path or url.".format(
- pretrained_model_name_or_path,
- ", ".join(cls.pretrained_model_archive_map.keys()),
- archive_file,
- )
- )
- raise e
- if resolved_archive_file == archive_file:
- log("loading weights file {}".format(archive_file))
- else:
- log("loading weights file {} from cache at {}".format(archive_file, resolved_archive_file))
- else:
- resolved_archive_file = None
- # Instantiate model.
- model = cls(config, *model_args, **model_kwargs)
- if from_pt:
- # Load from a PyTorch checkpoint
- raise NotImplementedError
- # return load_pytorch_checkpoint_in_tf2_model(model, resolved_archive_file, allow_missing_keys=True)
- model(model.dummy_inputs, training=False) # build the network with dummy inputs
- assert os.path.isfile(resolved_archive_file), "Error retrieving file {}".format(resolved_archive_file)
- # 'by_name' allows us to do transfer learning by skipping/adding layers
- # see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
- try:
- model.load_weights(resolved_archive_file, by_name=True)
- except OSError:
- raise OSError(
- "Unable to load weights from h5 file. "
- "If you tried to load a TF 2.0 model from a PyTorch checkpoint, please set from_pt=True. "
- )
- model(model.dummy_inputs, training=False) # Make sure restore ops are run
- # Check if the models are the same to output loading information
- with h5py.File(resolved_archive_file, "r") as f:
- if "layer_names" not in f.attrs and "model_weights" in f:
- f = f["model_weights"]
- hdf5_layer_names = set(hdf5_format.load_attributes_from_hdf5_group(f, "layer_names"))
- model_layer_names = set(layer.name for layer in model.layers)
- missing_keys = list(model_layer_names - hdf5_layer_names)
- unexpected_keys = list(hdf5_layer_names - model_layer_names)
- error_msgs = []
- if len(unexpected_keys) > 0:
- log(
- f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
- f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
- )
- else:
- log(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
- if len(missing_keys) > 0:
- log(
- f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
- f"and are newly initialized: {missing_keys}\n"
- )
- else:
- log(
- f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
- f"If your task is similar to the task the model of the ckeckpoint was trained on, "
- f"you can already use {model.__class__.__name__} for predictions without further training."
- )
- if len(error_msgs) > 0:
- raise RuntimeError(
- "Error(s) in loading weights for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs))
- )
- if output_loading_info:
- loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys, "error_msgs": error_msgs}
- return model, loading_info
- return model
- def prepare_inputs_for_generation(self, inputs, **kwargs):
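- """ Default pass-through used by `generate`; models that support caching override this to also feed `past` and `attention_mask` to the forward call. """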
- return {"inputs": inputs}
- def _do_output_past(self, outputs):
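- """ Whether `outputs[1]` holds a decoding cache: requires `config.output_past` (or a positive Transformer-XL style `config.mem_len`) and more than one model output. """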
- has_output_past = hasattr(self.config, "output_past") and self.config.output_past
- has_mem_len = hasattr(self.config, "mem_len") and self.config.mem_len
- if has_output_past and not has_mem_len and len(outputs) > 1:
- return True
- elif has_mem_len and self.config.mem_len > 0 and len(outputs) > 1:
- return True
- return False
- def generate(
- self,
- input_ids=None,
- max_length=None,
- min_length=None,
- do_sample=None,
- early_stopping=None,
- num_beams=None,
- temperature=None,
- top_k=None,
- top_p=None,
- repetition_penalty=None,
- bad_words_ids=None,
- bos_token_id=None,
- pad_token_id=None,
- eos_token_id=None,
- length_penalty=None,
- no_repeat_ngram_size=None,
- num_return_sequences=None,
- attention_mask=None,
- decoder_start_token_id=None,
- ):
- r""" Generates sequences for models with a LM head. The method currently supports greedy or penalized greedy decoding, sampling with top-k or nucleus sampling
- and beam-search.
- Adapted in part from `Facebook's XLM beam search code`_.
- .. _`Facebook's XLM beam search code`:
- https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529
- Parameters:
- input_ids: (`optional`) `tf.Tensor` of `dtype=tf.int32` of shape `(batch_size, sequence_length)`
- The sequence used as a prompt for the generation. If `None` the method initializes
- it as a `tf.Tensor` of shape `(1, 1)` filled with `bos_token_id`.
- max_length: (`optional`) int
- The max length of the sequence to be generated. Between 1 and infinity. Defaults to 20.
- min_length: (`optional`) int
- The min length of the sequence to be generated. Between 0 and infinity. Defaults to 0.
- do_sample: (`optional`) bool
- If set to `False` greedy decoding is used. Otherwise sampling is used. Defaults to `False` as defined in `configuration_utils.PretrainedConfig`.
- early_stopping: (`optional`) bool
- If set to `True` beam search is stopped when at least `num_beams` sentences are finished per batch. Defaults to `False` as defined in `configuration_utils.PretrainedConfig`.
- num_beams: (`optional`) int
- Number of beams for beam search. Must be between 1 and infinity. 1 means no beam search. Defaults to 1.
- temperature: (`optional`) float
- The value used to modulate the next token probabilities. Must be strictly positive. Defaults to 1.0.
- top_k: (`optional`) int
- The number of highest probability vocabulary tokens to keep for top-k filtering. Between 1 and infinity. Defaults to 50.
- top_p: (`optional`) float
- Nucleus sampling threshold: only the smallest set of most probable vocabulary tokens whose cumulative probability is at least `top_p` is kept. Must be between 0 and 1. Defaults to 1.
- repetition_penalty: (`optional`) float
- The parameter for repetition penalty. Between 1.0 and infinity. 1.0 means no penalty. Defaults to 1.0.
- bos_token_id: (`optional`) int
- Beginning of sentence token used if no prompt is provided. Defaults to the model-specific `bos_token_id`, or `None` if it does not exist.
- pad_token_id: (`optional`) int
- Pad token. Defaults to `pad_token_id` as defined in the model's config.
- eos_token_id: (`optional`) int
- EOS token. Defaults to `eos_token_id` as defined in the model's config.
- length_penalty: (`optional`) float
- Exponential penalty to the length. Defaults to 1.
- no_repeat_ngram_size: (`optional`) int
- If set to int > 0, all ngrams of size `no_repeat_ngram_size` can only occur once.
- bad_words_ids: (`optional`) list of lists of int
- `bad_words_ids` contains tokens that are not allowed to be generated. In order to get the tokens of the words that should not appear in the generated text, use `tokenizer.encode(bad_word, add_prefix_space=True)`.
- num_return_sequences: (`optional`) int
- The number of independently computed returned sequences for each element in the batch. Defaults to 1.
- attention_mask (`optional`) obj: `tf.Tensor` with `dtype=tf.int32` of same shape as `input_ids`
- Mask to avoid performing attention on padding token indices.
- Mask values selected in ``[0, 1]``:
- ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
- Defaults to `None`.
- `What are attention masks? <../glossary.html#attention-mask>`__
- decoder_start_token_id: (`optional`) int
- Token used to start decoding if an encoder-decoder model starts decoding with a different token than BOS.
- Defaults to `None`, in which case `bos_token_id` is used instead.
- Return:
- output: `tf.Tensor` of `dtype=tf.int32` shape `(batch_size * num_return_sequences, sequence_length)`
- sequence_length is either equal to max_length or shorter if all batches finished early due to the `eos_token_id`
- Examples::
- tokenizer = AutoTokenizer.from_pretrained('distilgpt2') # Initialize tokenizer
- model = TFAutoModelWithLMHead.from_pretrained('distilgpt2') # Download model and configuration from S3 and cache.
- outputs = model.generate(max_length=40) # do greedy decoding
- print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))
- tokenizer = AutoTokenizer.from_pretrained('openai-gpt') # Initialize tokenizer
- model = TFAutoModelWithLMHead.from_pretrained('openai-gpt') # Download model and configuration from S3 and cache.
- input_context = 'The dog'
- input_ids = tokenizer.encode(input_context, return_tensors='tf') # encode input context
- outputs = model.generate(input_ids=input_ids, num_beams=5, num_return_sequences=3, temperature=1.5) # generate 3 independent sequences using beam search decoding (5 beams) from initial context 'The dog'
- for i in range(3): # 3 output sequences were generated
- print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True)))
- tokenizer = AutoTokenizer.from_pretrained('distilgpt2') # Initialize tokenizer
- model = TFAutoModelWithLMHead.from_pretrained('distilgpt2') # Download model and configuration from S3 and cache.
- input_context = 'The dog'
- input_ids = tokenizer.encode(input_context, return_tensors='tf') # encode input context
- outputs = model.generate(input_ids=input_ids, max_length=40, do_sample=True, temperature=0.7, num_return_sequences=3) # generate 3 independent sequences by sampling
- for i in range(3): # 3 output sequences were generated
- print('Generated {}: {}'.format(i, tokenizer.decode(outputs[i], skip_special_tokens=True)))
- tokenizer = AutoTokenizer.from_pretrained('ctrl') # Initialize tokenizer
- model = TFAutoModelWithLMHead.from_pretrained('ctrl') # Download model and configuration from S3 and cache.
- input_context = 'Legal My neighbor is' # "Legal" is one of the control codes for ctrl
- input_ids = tokenizer.encode(input_context, return_tensors='tf') # encode input context
- outputs = model.generate(input_ids=input_ids, max_length=50, temperature=0.7, repetition_penalty=1.2) # generate sequences
- print('Generated: {}'.format(tokenizer.decode(outputs[0], skip_special_tokens=True)))
- tokenizer = AutoTokenizer.from_pretrained('gpt2') # Initialize tokenizer
- model = TFAutoModelWithLMHead.from_pretrained('gpt2') # Download model and configuration from S3 and cache.
- input_context = 'My cute dog'
- bad_words_ids = [tokenizer.encode(bad_word, add_prefix_space=True) for bad_word in ['idiot', 'stupid', 'shut up']]
- input_ids = tokenizer.encode(input_context, return_tensors='tf') # encode input context
- outputs = model.generate(input_ids=input_ids, max_length=100, do_sample=True, bad_words_ids=bad_words_ids) # generate sequences without allowing bad_words to be generated
- """
- # We cannot generate if the model does not have a LM head
- if self.get_output_embeddings() is None:
- raise AttributeError(
- "You tried to generate sequences with a model that does not have a LM Head."
- "Please use another model class (e.g. `TFOpenAIGPTLMHeadModel`, `TFXLNetLMHeadModel`, `TFGPT2LMHeadModel`, `TFCTRLLMHeadModel`, `TFT5ForConditionalGeneration`, `TFTransfoXLLMHeadModel`)"
- )
- max_length = max_length if max_length is not None else self.config.max_length
- min_length = min_length if min_length is not None else self.config.min_length
- do_sample = do_sample if do_sample is not None else self.config.do_sample
- early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
- num_beams = num_beams if num_beams is not None else self.config.num_beams
- temperature = temperature if temperature is not None else self.config.temperature
- top_k = top_k if top_k is not None else self.config.top_k
- top_p = top_p if top_p is not None else self.config.top_p
- repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty
- bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
- pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
- eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
- length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
- no_repeat_ngram_size = (
- no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
- )
- bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids
- num_return_sequences = (
- num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
- )
- decoder_start_token_id = (
- decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
- )
- if input_ids is not None:
- batch_size = shape_list(input_ids)[0] # overridden by the input batch_size
- else:
- batch_size = 1
- assert isinstance(max_length, int) and max_length > 0, "`max_length` should be a strictly positive integer."
- assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a non-negative integer."
- assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
- assert isinstance(early_stopping, bool), "`early_stopping` should be a boolean."
- assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictly positive integer."
- assert temperature > 0, "`temperature` should be strictly positive."
- assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a non-negative integer."
- assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1."
- assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1."
- assert input_ids is not None or (
- isinstance(bos_token_id, int) and bos_token_id >= 0
- ), "If input_ids is not defined, `bos_token_id` should be a non-negative integer."
- assert pad_token_id is None or (
- isinstance(pad_token_id, int) and (pad_token_id >= 0)
- ), "`pad_token_id` should be a non-negative integer."
- assert (eos_token_id is None) or (
- isinstance(eos_token_id, int) and (eos_token_id >= 0)
- ), "`eos_token_id` should be a non-negative integer."
- assert length_penalty > 0, "`length_penalty` should be strictly positive."
- assert (
- isinstance(num_return_sequences, int) and num_return_sequences > 0
- ), "`num_return_sequences` should be a strictly positive integer."
- assert (
- bad_words_ids is None or isinstance(bad_words_ids, list) and isinstance(bad_words_ids[0], list)
- ), "`bad_words_ids` is either `None` or a list of lists of tokens that should not be generated"
- if input_ids is None:
- assert isinstance(bos_token_id, int) and bos_token_id >= 0, (
- "you should either supply a context to complete as `input_ids` input "
- "or a `bos_token_id` (integer >= 0) as a first token to start the generation."
- )
- input_ids = tf.fill((batch_size, 1), bos_token_id)
- else:
- assert len(shape_list(input_ids)) == 2, "Input prompt should be of shape (batch_size, sequence length)."
- # do not allow duplicate outputs when greedy decoding
- if do_sample is False:
- if num_beams == 1:
- # no_beam_search greedy generation conditions
- assert (
- num_return_sequences == 1
- ), "Greedy decoding will always produce the same output for num_beams == 1 and num_return_sequences > 1. Please set num_return_sequences = 1"
- else:
- # beam_search greedy generation conditions
- assert (
- num_beams >= num_return_sequences
- ), "Greedy beam search decoding cannot return more sequences than it has beams. Please set num_beams >= num_return_sequences"
- # create attention mask if necessary
- # TODO (PVP): this should later be handled by the forward fn() in each model (see PR 3140)
- if (attention_mask is None) and (pad_token_id is not None) and (pad_token_id in input_ids.numpy()):
- attention_mask = tf.cast(tf.math.not_equal(input_ids, pad_token_id), dtype=tf.int32)
- elif attention_mask is None:
- attention_mask = tf.ones_like(input_ids)
- if pad_token_id is None and eos_token_id is not None:
- log(
- "Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_id)
- )
- pad_token_id = eos_token_id
- # current position and vocab size
- cur_len = shape_list(input_ids)[1]
- vocab_size = self.config.vocab_size
- # set effective batch size and effective batch multiplier according to do_sample
- if do_sample:
- effective_batch_size = batch_size * num_return_sequences
- effective_batch_mult = num_return_sequences
- else:
- effective_batch_size = batch_size
- effective_batch_mult = 1
- # Expand input ids if num_beams > 1 or num_return_sequences > 1
- if num_return_sequences > 1 or num_beams > 1:
- input_ids_len = shape_list(input_ids)[-1]
- input_ids = tf.broadcast_to(
- tf.expand_dims(input_ids, 1), (batch_size, effective_batch_mult * num_beams, input_ids_len)
- )
- attention_mask = tf.broadcast_to(
- tf.expand_dims(attention_mask, 1), (batch_size, effective_batch_mult * num_beams, input_ids_len)
- )
- input_ids = tf.reshape(
- input_ids, (effective_batch_size * num_beams, input_ids_len)
- ) # shape: (batch_size * num_return_sequences * num_beams, cur_len)
- attention_mask = tf.reshape(
- attention_mask, (effective_batch_size * num_beams, input_ids_len)
- ) # shape: (batch_size * num_return_sequences * num_beams, cur_len)
- if self.config.is_encoder_decoder:
- if decoder_start_token_id is None:
- decoder_start_token_id = bos_token_id
- assert (
- decoder_start_token_id is not None
- ), "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation"
- assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self)
- assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder)
- # get encoder and store encoder outputs
- encoder = self.get_encoder()
- encoder_outputs = encoder(input_ids, attention_mask=attention_mask)
- # create empty decoder_input_ids
- input_ids = tf.ones((effective_batch_size * num_beams, 1), dtype=tf.int32,) * decoder_start_token_id
- cur_len = 1
- else:
- encoder_outputs = None
- cur_len = shape_list(input_ids)[-1]
- if num_beams > 1:
- output = self._generate_beam_search(
- input_ids,
- cur_len=cur_len,
- max_length=max_length,
- min_length=min_length,
- do_sample=do_sample,
- early_stopping=early_stopping,
- temperature=temperature,
- top_k=top_k,
- top_p=top_p,
- repetition_penalty=repetition_penalty,
- no_repeat_ngram_size=no_repeat_ngram_size,
- bad_words_ids=bad_words_ids,
- bos_token_id=bos_token_id,
- pad_token_id=pad_token_id,
- eos_token_id=eos_token_id,
- decoder_start_token_id=decoder_start_token_id,
- batch_size=effective_batch_size,
- num_return_sequences=num_return_sequences,
- length_penalty=length_penalty,
- num_beams=num_beams,
- vocab_size=vocab_size,
- encoder_outputs=encoder_outputs,
- attention_mask=attention_mask,
- )
- else:
- output = self._generate_no_beam_search(
- input_ids,
- cur_len=cur_len,
- max_length=max_length,
- min_length=min_length,
- do_sample=do_sample,
- temperature=temperature,
- top_k=top_k,
- top_p=top_p,
- repetition_penalty=repetition_penalty,
- no_repeat_ngram_size=no_repeat_ngram_size,
- bad_words_ids=bad_words_ids,
- bos_token_id=bos_token_id,
- pad_token_id=pad_token_id,
- eos_token_id=eos_token_id,
- decoder_start_token_id=decoder_start_token_id,
- batch_size=effective_batch_size,
- vocab_size=vocab_size,
- encoder_outputs=encoder_outputs,
- attention_mask=attention_mask,
- )
- return output
- def _generate_no_beam_search(
- self,
- input_ids,
- cur_len,
- max_length,
- min_length,
- do_sample,
- temperature,
- top_k,
- top_p,
- repetition_penalty,
- no_repeat_ngram_size,
- bad_words_ids,
- bos_token_id,
- pad_token_id,
- eos_token_id,
- decoder_start_token_id,
- batch_size,
- vocab_size,
- encoder_outputs,
- attention_mask,
- ):
- """ Generate sequences for each example without beam search (num_beams == 1).
- All returned sequences are generated independently.
- """
- # length of generated sentences / unfinished sentences
- unfinished_sents = tf.ones_like(input_ids[:, 0])
- sent_lengths = tf.ones_like(input_ids[:, 0]) * max_length
- past = encoder_outputs # defined for encoder-decoder models, None for decoder-only models
- while cur_len < max_length:
- model_inputs = self.prepare_inputs_for_generation(input_ids, past=past, attention_mask=attention_mask)
- outputs = self(**model_inputs)
- next_token_logits = outputs[0][:, -1, :]
- # if model has past, then set the past variable to speed up decoding
- if self._do_output_past(outputs):
- past = outputs[1]
- # repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
- if repetition_penalty != 1.0:
- next_token_logits_penalties = _create_next_token_logits_penalties(
- input_ids, next_token_logits, repetition_penalty
- )
- next_token_logits = tf.math.multiply(next_token_logits, next_token_logits_penalties)
- if no_repeat_ngram_size > 0:
- # calculate a list of banned tokens to prevent repetitively generating the same ngrams
- # from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
- banned_tokens = calc_banned_ngram_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len)
- # create banned_tokens boolean mask
- banned_tokens_indices_mask = []
- for banned_tokens_slice in banned_tokens:
- banned_tokens_indices_mask.append(
- [token in banned_tokens_slice for token in range(vocab_size)]
- )
- next_token_logits = set_tensor_by_indices_to_value(
- next_token_logits, tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf")
- )
- if bad_words_ids is not None:
- # calculate a list of banned tokens according to bad words
- banned_tokens = calc_banned_bad_words_ids(input_ids, bad_words_ids)
- banned_tokens_indices_mask = []
- for banned_tokens_slice in banned_tokens:
- banned_tokens_indices_mask.append(
- [token in banned_tokens_slice for token in range(vocab_size)]
- )
- next_token_logits = set_tensor_by_indices_to_value(
- next_token_logits, tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf")
- )
- # set eos token prob to zero if min_length is not reached
- if eos_token_id is not None and cur_len < min_length:
- # create eos_token_id boolean mask
- is_token_logit_eos_token = tf.convert_to_tensor(
- [token == eos_token_id for token in range(vocab_size)], dtype=tf.bool
- )
- eos_token_indices_mask = tf.broadcast_to(is_token_logit_eos_token, [batch_size, vocab_size])
- next_token_logits = set_tensor_by_indices_to_value(
- next_token_logits, eos_token_indices_mask, -float("inf")
- )
- if do_sample:
- # Temperature (higher temperature => more likely to sample low probability tokens)
- if temperature != 1.0:
- next_token_logits = next_token_logits / temperature
- # Top-p/top-k filtering
- next_token_logits = tf_top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
- # Sample
- next_token = tf.squeeze(
- tf.random.categorical(next_token_logits, dtype=tf.int32, num_samples=1), axis=1
- )
- else:
- # Greedy decoding
- next_token = tf.math.argmax(next_token_logits, axis=-1, output_type=tf.int32)
- # update generations and finished sentences
- if eos_token_id is not None:
- # pad finished sentences if eos_token_id exist
- tokens_to_add = next_token * unfinished_sents + (pad_token_id) * (1 - unfinished_sents)
- else:
- tokens_to_add = next_token
- input_ids = tf.concat([input_ids, tf.expand_dims(tokens_to_add, -1)], 1)
- if eos_token_id is not None:
- eos_in_sents = tokens_to_add == eos_token_id
- # if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length
- is_sents_unfinished_and_token_to_add_is_eos = tf.math.multiply(
- unfinished_sents, tf.cast(eos_in_sents, tf.int32)
- )
- sent_lengths = (
- sent_lengths * (1 - is_sents_unfinished_and_token_to_add_is_eos)
- + cur_len * is_sents_unfinished_and_token_to_add_is_eos
- )
- # unfinished_sents is set to zero if eos in sentence
- unfinished_sents -= is_sents_unfinished_and_token_to_add_is_eos
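- # e.g. with unfinished_sents=[1, 1] and eos_in_sents=[1, 0] at cur_len=3:
- # sent_lengths becomes [3, max_length] and unfinished_sents becomes [0, 1]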
- # stop when there is a </s> in each sentence, or if we exceed the maximum length
- if tf.math.reduce_max(unfinished_sents) == 0:
- break
- # extend attention_mask for new generated input if only decoder
- if self.config.is_encoder_decoder is False:
- attention_mask = tf.concat(
- [attention_mask, tf.ones((shape_list(attention_mask)[0], 1), dtype=tf.int32)], axis=-1
- )
- cur_len = cur_len + 1
- # if there are different sentence lengths in the batch, some sentences have to be padded
- min_sent_length = tf.math.reduce_min(sent_lengths)
- max_sent_length = tf.math.reduce_max(sent_lengths)
- if min_sent_length != max_sent_length:
- assert pad_token_id is not None, "`pad_token_id` has to be defined if batches have different lengths"
- # finished sents are filled with pad_token
- padding = tf.ones([batch_size, max_sent_length.numpy()], dtype=tf.int32) * pad_token_id
- # create length masks for tf.where operation
- broad_casted_sent_lengths = tf.broadcast_to(
- tf.expand_dims(sent_lengths, -1), [batch_size, max_sent_length]
- )
- broad_casted_range = tf.transpose(
- tf.broadcast_to(tf.expand_dims(tf.range(max_length), -1), [max_length, batch_size])
- )
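- # keep each row's tokens before its sentence length and pad the rest; e.g. with
- # sent_lengths=[3, 5] and max_sent_length=5, row 0 keeps its first 3 tokens and is
- # padded with pad_token_id afterwards, while row 1 is returned unchanged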
- decoded = tf.where(broad_casted_range < broad_casted_sent_lengths, input_ids, padding)
- else:
- decoded = input_ids
- return decoded
- def _generate_beam_search(
- self,
- input_ids,
- cur_len,
- max_length,
- min_length,
- do_sample,
- early_stopping,
- temperature,
- top_k,
- top_p,
- repetition_penalty,
- no_repeat_ngram_size,
- bad_words_ids,
- bos_token_id,
- pad_token_id,
- decoder_start_token_id,
- eos_token_id,
- batch_size,
- num_return_sequences,
- length_penalty,
- num_beams,
- vocab_size,
- encoder_outputs,
- attention_mask,
- ):
- """ Generate sequences for each example with beam search.
- """
- # generated hypotheses
- generated_hyps = [
- BeamHypotheses(num_beams, max_length, length_penalty, early_stopping=early_stopping)
- for _ in range(batch_size)
- ]
- # for greedy decoding it is made sure that only tokens of the first beam are considered, to avoid sampling the exact same tokens num_beams times
- if do_sample is False:
- beam_scores_begin = tf.zeros((batch_size, 1), dtype=tf.float32)
- beam_scores_end = tf.ones((batch_size, num_beams - 1), dtype=tf.float32) * (-1e9)
- beam_scores = tf.concat([beam_scores_begin, beam_scores_end], -1)
- else:
- beam_scores = tf.zeros((batch_size, num_beams), dtype=tf.float32)
- beam_scores = tf.reshape(beam_scores, (batch_size * num_beams,))
- # cache compute states
- past = encoder_outputs
- # done sentences
- done = [False for _ in range(batch_size)]
- while cur_len < max_length:
- model_inputs = self.prepare_inputs_for_generation(input_ids, past=past, attention_mask=attention_mask)
- outputs = self(**model_inputs) # (batch_size * num_beams, cur_len, vocab_size)
- next_token_logits = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size)
- # if model has past, then set the past variable to speed up decoding
- if self._do_output_past(outputs):
- past = outputs[1]
- # repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
- if repetition_penalty != 1.0:
- next_token_logits_penalties = _create_next_token_logits_penalties(
- input_ids, next_token_logits, repetition_penalty
- )
- next_token_logits = tf.math.multiply(next_token_logits, next_token_logits_penalties)
- # Temperature (higher temperature => more likely to sample low probability tokens)
- if temperature != 1.0:
- next_token_logits = next_token_logits / temperature
- # calculate log softmax score
- scores = tf.nn.log_softmax(next_token_logits, axis=-1) # (batch_size * num_beams, vocab_size)
- # set eos token prob to zero if min_length is not reached
- if eos_token_id is not None and cur_len < min_length:
- # create eos_token_id boolean mask
- num_batch_hypotheses = batch_size * num_beams
- is_token_logit_eos_token = tf.convert_to_tensor(
- [token == eos_token_id for token in range(vocab_size)], dtype=tf.bool
- )
- eos_token_indices_mask = tf.broadcast_to(is_token_logit_eos_token, [num_batch_hypotheses, vocab_size])
- scores = set_tensor_by_indices_to_value(scores, eos_token_indices_mask, -float("inf"))
- if no_repeat_ngram_size > 0:
- # calculate a list of banned tokens to prevent repetitively generating the same ngrams
- # from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
- num_batch_hypotheses = batch_size * num_beams
- banned_tokens = calc_banned_ngram_tokens(
- input_ids, num_batch_hypotheses, no_repeat_ngram_size, cur_len
- )
- # create banned_tokens boolean mask
- banned_tokens_indices_mask = []
- for banned_tokens_slice in banned_tokens:
- banned_tokens_indices_mask.append(
- [token in banned_tokens_slice for token in range(vocab_size)]
- )
- scores = set_tensor_by_indices_to_value(
- scores, tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf")
- )
- if bad_words_ids is not None:
- # calculate a list of banned tokens according to bad words
- banned_tokens = calc_banned_bad_words_ids(input_ids, bad_words_ids)
- banned_tokens_indices_mask = []
- for banned_tokens_slice in banned_tokens:
- banned_tokens_indices_mask.append(
- [token in banned_tokens_slice for token in range(vocab_size)]
- )
- scores = set_tensor_by_indices_to_value(
- scores, tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf")
- )
- assert shape_list(scores) == [batch_size * num_beams, vocab_size]
- if do_sample:
- _scores = scores + tf.broadcast_to(
- beam_scores[:, None], (batch_size * num_beams, vocab_size)
- ) # (batch_size * num_beams, vocab_size)
- # Top-p/top-k filtering
- _scores = tf_top_k_top_p_filtering(
- _scores, top_k=top_k, top_p=top_p, min_tokens_to_keep=2
- ) # (batch_size * num_beams, vocab_size)
- # Sample 2 next tokens for each beam (so we have some spare tokens and match output of greedy beam search)
- _scores = tf.reshape(_scores, (batch_size, num_beams * vocab_size))
- next_tokens = tf.random.categorical(
- _scores, dtype=tf.int32, num_samples=2 * num_beams
- ) # (batch_size, 2 * num_beams)
- # Compute next scores
- next_scores = tf.gather(_scores, next_tokens, batch_dims=1) # (batch_size, 2 * num_beams)
- # sort the sampled vector to make sure that the first num_beams samples are the best
- next_scores_indices = tf.argsort(next_scores, direction="DESCENDING", axis=1)
- next_scores = tf.gather(next_scores, next_scores_indices, batch_dims=1) # (batch_size, num_beams * 2)
- next_tokens = tf.gather(next_tokens, next_scores_indices, batch_dims=1) # (batch_size, num_beams * 2)
- else:
- # Add the log prob of the new beams to the log prob of the beginning of the sequence (sum of logs == log of the product)
- next_scores = scores + tf.broadcast_to(
- beam_scores[:, None], (batch_size * num_beams, vocab_size)
- ) # (batch_size * num_beams, vocab_size)
- # re-organize to group beams together (we keep the top hypotheses across beams)
- next_scores = tf.reshape(
- next_scores, (batch_size, num_beams * vocab_size)
- ) # (batch_size, num_beams * vocab_size)
- next_scores, next_tokens = tf.math.top_k(next_scores, k=2 * num_beams, sorted=True)
- assert shape_list(next_scores) == shape_list(next_tokens) == [batch_size, 2 * num_beams]
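- # Note (illustrative): keeping 2 * num_beams candidates means that even if num_beams of them
- # end in eos_token_id, enough open continuations remain to refill every beam. Each candidate
- # index encodes its source beam and token as beam_id * vocab_size + token_id (decoded below).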
- # next batch beam content
- next_batch_beam = []
- # for each sentence
- for batch_idx in range(batch_size):
- # if we are done with this sentence
- if done[batch_idx]:
- assert (
- len(generated_hyps[batch_idx]) >= num_beams
- ), "Batch can only be done if at least {} beams have been generated".format(num_beams)
- assert (
- eos_token_id is not None and pad_token_id is not None
- ), "generated beams >= num_beams -> eos_token_id and pad_token have to be defined"
- next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams) # pad the batch
- continue
- # next sentence beam content
- next_sent_beam = []
- # next tokens for this sentence
- for beam_token_rank, (beam_token_id, beam_token_score) in enumerate(
- zip(next_tokens[batch_idx], next_scores[batch_idx])
- ):
- # get beam and token IDs
- beam_id = beam_token_id // vocab_size
- token_id = beam_token_id % vocab_size
- effective_beam_id = batch_idx * num_beams + beam_id
- # add to generated hypotheses if end of sentence or last iteration
- if (eos_token_id is not None) and (token_id.numpy() == eos_token_id):
- # if beam_token does not belong to top num_beams tokens, it should not be added
- is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams
- if is_beam_token_worse_than_top_num_beams:
- continue
- generated_hyps[batch_idx].add(
- tf.identity(input_ids[effective_beam_id]), beam_token_score.numpy()
- )
- else:
- # add next predicted token if it is not eos_token
- next_sent_beam.append((beam_token_score, token_id, effective_beam_id))
- # the beam for next step is full
- if len(next_sent_beam) == num_beams:
- break
- # Check if we're done so that we can save a pad step if all(done)
- done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done(
- tf.reduce_max(next_scores[batch_idx]).numpy(), cur_len=cur_len
- )
- # update next beam content
- assert len(next_sent_beam) == num_beams, "Beam should always be full"
- next_batch_beam.extend(next_sent_beam)
- assert len(next_batch_beam) == num_beams * (batch_idx + 1)
- # stop when we are done with each sentence
- if all(done):
- break
- # sanity check / prepare next batch
- assert len(next_batch_beam) == batch_size * num_beams
- beam_scores = tf.convert_to_tensor([x[0] for x in next_batch_beam], dtype=tf.float32)
- beam_tokens = tf.convert_to_tensor([x[1] for x in next_batch_beam], dtype=tf.int32)
- beam_idx = tf.convert_to_tensor([x[2] for x in next_batch_beam], dtype=tf.int32)
- # re-order batch
- input_ids = tf.stack([tf.identity(input_ids[x, :]) for x in beam_idx])
- input_ids = tf.concat([input_ids, tf.expand_dims(beam_tokens, 1)], axis=-1)
- # re-order internal states
- if past is not None:
- past = self._reorder_cache(past, beam_idx)
- # extend attention_mask for new generated input if only decoder
- if self.config.is_encoder_decoder is False:
- attention_mask = tf.concat(
- [attention_mask, tf.ones((shape_list(attention_mask)[0], 1), dtype=tf.int32)], axis=-1
- )
- # update current length
- cur_len = cur_len + 1
- # finalize all open beam hypotheses and add them to generated hypotheses
- for batch_idx in range(batch_size):
- # Add all open beam hypothesis to generated_hyps
- if done[batch_idx]:
- continue
- # test that beam scores match previously calculated scores if not eos and batch_idx not done
- if eos_token_id is not None and all(
- (token_id % vocab_size).numpy().item() != eos_token_id for token_id in next_tokens[batch_idx]
- ):
- assert tf.reduce_all(
- next_scores[batch_idx, :num_beams] == tf.reshape(beam_scores, (batch_size, num_beams))[batch_idx]
- ), "If batch_idx is not done, final next scores: {} have to equal to accumulated beam_scores: {}".format(
- next_scores[:, :num_beams][batch_idx], tf.reshape(beam_scores, (batch_size, num_beams))[batch_idx]
- )
- # need to add best num_beams hypotheses to generated hyps
- for beam_id in range(num_beams):
- effective_beam_id = batch_idx * num_beams + beam_id
- final_score = beam_scores[effective_beam_id].numpy().item()
- final_tokens = input_ids[effective_beam_id]
- generated_hyps[batch_idx].add(final_tokens, final_score)
- # define output_batch_size and output_num_return_sequences_per_batch depending on whether we sample or do greedy beam search
- output_batch_size = batch_size if do_sample else batch_size * num_return_sequences
- output_num_return_sequences_per_batch = 1 if do_sample else num_return_sequences
- # select the best hypotheses
- sent_lengths_list = []
- best = []
- # retrieve best hypotheses
- for i, hypotheses in enumerate(generated_hyps):
- sorted_hyps = sorted(hypotheses.beams, key=lambda x: x[0])
- for j in range(output_num_return_sequences_per_batch):
- best_hyp = sorted_hyps.pop()[1]
- sent_lengths_list.append(len(best_hyp))
- best.append(best_hyp)
- assert output_batch_size == len(best), "Output batch size {} must match output beam hypotheses {}".format(
- output_batch_size, len(best)
- )
- sent_lengths = tf.convert_to_tensor(sent_lengths_list, dtype=tf.int32)
- # shorter hypotheses are filled with pad_token
- if tf.reduce_min(sent_lengths).numpy() != tf.reduce_max(sent_lengths).numpy():
- assert pad_token_id is not None, "`Pad_token_id` has to be defined"
- sent_max_len = min(tf.reduce_max(sent_lengths).numpy() + 1, max_length)
- decoded_list = []
- # fill with hypothesis and eos_token_id if necessary
- for i, hypo in enumerate(best):
- assert sent_lengths[i] == shape_list(hypo)[0]
- # if sent_length is max_len do not pad
- if sent_lengths[i] == sent_max_len:
- decoded_slice = hypo
- else:
- # else pad to sent_max_len
- num_pad_tokens = sent_max_len - sent_lengths[i]
- padding = pad_token_id * tf.ones((num_pad_tokens,), dtype=tf.int32)
- decoded_slice = tf.concat([hypo, padding], axis=-1)
- # finish sentence with EOS token
- if sent_lengths[i] < max_length:
- decoded_slice = tf.where(
- tf.range(sent_max_len, dtype=tf.int32) == sent_lengths[i],
- eos_token_id * tf.ones((sent_max_len,), dtype=tf.int32),
- decoded_slice,
- )
- # add to list
- decoded_list.append(decoded_slice)
- decoded = tf.stack(decoded_list)
- else:
- # none of the hypotheses have an eos_token
- assert all(len(hypo) == max_length for hypo in best)
- decoded = tf.stack(best)
- return decoded
- @staticmethod
- def _reorder_cache(past, beam_idx):
- reordered_past = []
- for layer_past in past:
- # get the correct batch idx from layer past batch dim
- # batch dim of `past` and `mems` is at 2nd position
- reordered_layer_past = [tf.identity(tf.expand_dims(layer_past[:, i], 1)) for i in beam_idx]
- reordered_layer_past = tf.concat(reordered_layer_past, axis=1)
- # check that shape matches
- assert shape_list(reordered_layer_past) == shape_list(layer_past)
- reordered_past.append(reordered_layer_past)
- past = tuple(reordered_past)
- return past
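- # Note (illustrative): after a step, beam b may continue the text of a different beam b';
- # beam_idx records that mapping, and the cached keys/values are gathered along their batch
- # dimension (axis 1 here) so the past stays aligned with the re-ordered input_ids.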
- def _create_next_token_logits_penalties(input_ids, logits, repetition_penalty):
- # create logit penalties for already seen input_ids
- token_penalties = np.ones(shape_list(logits))
- prev_input_ids = [np.unique(input_id) for input_id in input_ids.numpy()]
- for i, prev_input_id in enumerate(prev_input_ids):
- logit_penalized = logits[i].numpy()[prev_input_id]
- logit_penalties = np.zeros(logit_penalized.shape)
- # if the previous logit score is < 0, multiply by the repetition penalty, else divide by it
- logit_penalties[logit_penalized < 0] = repetition_penalty
- logit_penalties[logit_penalized > 0] = 1 / repetition_penalty
- np.put(token_penalties[i], prev_input_id, logit_penalties)
- return tf.convert_to_tensor(token_penalties, dtype=tf.float32)
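- # Example (illustrative numbers): with repetition_penalty=1.2, a previously seen token with
- # logit 2.0 becomes 2.0 * (1 / 1.2) ~= 1.67, while one with logit -2.0 becomes -2.0 * 1.2 = -2.4;
- # both moves lower the token's score regardless of sign, as in the CTRL paper.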
- def calc_banned_ngram_tokens(prev_input_ids, num_hypos, no_repeat_ngram_size, cur_len):
- # Copied from fairseq for no_repeat_ngram in beam_search
- if cur_len + 1 < no_repeat_ngram_size:
- # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
- return [[] for _ in range(num_hypos)]
- generated_ngrams = [{} for _ in range(num_hypos)]
- for idx in range(num_hypos):
- gen_tokens = prev_input_ids[idx].numpy().tolist()
- generated_ngram = generated_ngrams[idx]
- for ngram in zip(*[gen_tokens[i:] for i in range(no_repeat_ngram_size)]):
- prev_ngram_tuple = tuple(ngram[:-1])
- generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]
- def _get_generated_ngrams(hypo_idx):
- # Before decoding the next token, prevent decoding of ngrams that have already appeared
- start_idx = cur_len + 1 - no_repeat_ngram_size
- ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:cur_len].numpy().tolist())
- return generated_ngrams[hypo_idx].get(ngram_idx, [])
- banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)]
- return banned_tokens
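- # Example (illustrative): with no_repeat_ngram_size=2, cur_len=3 and a hypothesis [5, 3, 5],
- # the generated bigrams are {(5,): [3], (3,): [5]}; the current prefix is (5,), so token 3
- # is banned for the next step and the function returns [[3]].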
- def calc_banned_bad_words_ids(prev_input_ids, bad_words_ids):
- banned_tokens = []
- def _tokens_match(prev_tokens, tokens):
- if len(tokens) == 0:
- # an empty prefix (single-token bad word) always matches
- return True
- if len(tokens) > len(prev_tokens):
- # if the bad word prefix is longer than the generated tokens, they can't match
- return False
- # tokens match if the generated sequence ends with the bad word prefix
- return prev_tokens[-len(tokens) :] == tokens
- for prev_input_ids_slice in prev_input_ids:
- banned_tokens_slice = []
- for banned_token_seq in bad_words_ids:
- assert len(banned_token_seq) > 0, "Banned word token sequences {} cannot contain an empty list".format(
- bad_words_ids
- )
- if _tokens_match(prev_input_ids_slice.numpy().tolist(), banned_token_seq[:-1]) is False:
- # if tokens do not match continue
- continue
- banned_tokens_slice.append(banned_token_seq[-1])
- banned_tokens.append(banned_tokens_slice)
- return banned_tokens
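- # Example (illustrative): with bad_words_ids=[[42, 7]], a hypothesis ending in token 42
- # matches the prefix [42], so token 7 is banned for the next step; a one-token bad word
- # such as [[13]] has an empty prefix and is therefore always banned.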
- def tf_top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1):
- """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
- Args:
- logits: logits distribution shape (batch size, vocabulary size)
- if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
- if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
- Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
- Make sure we keep at least min_tokens_to_keep per batch example in the output
- From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
- """
- logits_shape = shape_list(logits)
- if top_k > 0:
- top_k = min(max(top_k, min_tokens_to_keep), logits_shape[-1]) # Safety check
- # Remove all tokens with a probability less than the last token of the top-k
- indices_to_remove = logits < tf.math.top_k(logits, k=top_k)[0][..., -1, None]
- logits = set_tensor_by_indices_to_value(logits, indices_to_remove, filter_value)
- if top_p < 1.0:
- sorted_indices = tf.argsort(logits, direction="DESCENDING")
- sorted_logits = tf.gather(
- logits, sorted_indices, axis=-1, batch_dims=1
- ) # expects logits to be of dim (batch_size, vocab_size)
- cumulative_probs = tf.math.cumsum(tf.nn.softmax(sorted_logits, axis=-1), axis=-1)
- # Remove tokens with cumulative probability above the threshold (tokens with 0 cumulative probability are kept)
- sorted_indices_to_remove = cumulative_probs > top_p
- if min_tokens_to_keep > 1:
- # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
- sorted_indices_to_remove = tf.concat(
- [
- tf.zeros_like(sorted_indices_to_remove[:, :min_tokens_to_keep]),
- sorted_indices_to_remove[:, min_tokens_to_keep:],
- ],
- -1,
- )
- # Shift the indices to the right to keep also the first token above the threshold
- sorted_indices_to_remove = tf.roll(sorted_indices_to_remove, 1, axis=-1)
- sorted_indices_to_remove = tf.concat(
- [tf.zeros_like(sorted_indices_to_remove[:, :1]), sorted_indices_to_remove[:, 1:]], -1,
- )
- # scatter sorted tensors to original indexing
- indices_to_remove = scatter_values_on_batch_indices(sorted_indices_to_remove, sorted_indices)
- logits = set_tensor_by_indices_to_value(logits, indices_to_remove, filter_value)
- return logits
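- # Example (illustrative): for logits [[1., 2., 3., 4.]], top_k=2 keeps only 3. and 4. (the
- # rest become filter_value). With top_p=0.9 instead, softmax gives ~[0.03, 0.09, 0.24, 0.64];
- # the descending cumulative sum first exceeds 0.9 at the third token, and the right-shift
- # keeps that token too, so only the lowest logit is filtered.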
- def scatter_values_on_batch_indices(values, batch_indices):
- shape = shape_list(batch_indices)
- # broadcast batch dim to shape
- broad_casted_batch_dims = tf.reshape(tf.broadcast_to(tf.expand_dims(tf.range(shape[0]), axis=-1), shape), [1, -1])
- # transform batch_indices to pair_indices
- pair_indices = tf.transpose(tf.concat([broad_casted_batch_dims, tf.reshape(batch_indices, [1, -1])], 0))
- # scatter values to pair indices
- return tf.scatter_nd(pair_indices, tf.reshape(values, [-1]), shape)
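- # Example (illustrative): values=[[a, b, c]] with batch_indices=[[2, 0, 1]] yields [[b, c, a]]:
- # each value is scattered back to the vocabulary position its sorted index came from.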
- def set_tensor_by_indices_to_value(tensor, indices, value):
- # create value_tensor since tensor value assignment is not possible in TF
- value_tensor = tf.zeros_like(tensor) + value
- return tf.where(indices, value_tensor, tensor)
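- # Example (illustrative): tensor [1., 2., 3.] with indices [True, False, True] and value -inf
- # yields [-inf, 2., -inf]; tf.where selects value_tensor wherever the mask is True.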
- class BeamHypotheses(object):
- def __init__(self, num_beams, max_length, length_penalty, early_stopping):
- """
- Initialize n-best list of hypotheses.
- """
- self.max_length = max_length - 1 # ignoring bos_token
- self.length_penalty = length_penalty
- self.early_stopping = early_stopping
- self.num_beams = num_beams
- self.beams = []
- self.worst_score = 1e9
- def __len__(self):
- """
- Number of hypotheses in the list.
- """
- return len(self.beams)
- def add(self, hyp, sum_logprobs):
- """
- Add a new hypothesis to the list.
- """
- score = sum_logprobs / len(hyp) ** self.length_penalty
- if len(self) < self.num_beams or score > self.worst_score:
- self.beams.append((score, hyp))
- if len(self) > self.num_beams:
- sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.beams)])
- del self.beams[sorted_scores[0][1]]
- self.worst_score = sorted_scores[1][0]
- else:
- self.worst_score = min(score, self.worst_score)
- def is_done(self, best_sum_logprobs, cur_len=None):
- """
- If there are enough hypotheses and none of the hypotheses being generated
- can become better than the worst one in the heap, then we are done with this sentence.
- """
- if len(self) < self.num_beams:
- return False
- elif self.early_stopping:
- return True
- else:
- if cur_len is None:
- cur_len = self.max_length
- cur_score = best_sum_logprobs / cur_len ** self.length_penalty
- ret = self.worst_score >= cur_score
- return ret
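- # Example (illustrative numbers): with length_penalty=1.0 and the heap full, if the worst kept
- # hypothesis scores -0.8 and an open beam has best_sum_logprobs=-9.0 at cur_len=10, its current
- # score is -0.9 < -0.8, so is_done returns True even without early_stopping.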
- class TFConv1D(tf.keras.layers.Layer):
- def __init__(self, nf, nx, initializer_range=0.02, **kwargs):
- """ TFConv1D layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2)
- Basically works like a Linear layer but the weights are transposed
- """
- super().__init__(**kwargs)
- self.nf = nf
- self.nx = nx
- self.initializer_range = initializer_range
- def build(self, input_shape):
- self.weight = self.add_weight(
- "weight", shape=[self.nx, self.nf], initializer=get_initializer(self.initializer_range)
- )
- self.bias = self.add_weight("bias", shape=[1, self.nf], initializer=tf.zeros_initializer())
- def call(self, x):
- bz, sl = shape_list(x)[:2]
- x = tf.reshape(x, [-1, self.nx])
- x = tf.matmul(x, self.weight) + self.bias
- x = tf.reshape(x, [bz, sl, self.nf])
- return x
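- # Note (illustrative): TFConv1D behaves like a Dense(nf) layer whose kernel is stored
- # transposed relative to a PyTorch Linear (x @ W + b with W of shape [nx, nf]); e.g. GPT-2
- # uses nf = 3 * hidden_size for its fused query/key/value projection.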
- class TFSharedEmbeddings(tf.keras.layers.Layer):
- """Construct shared token embeddings.
- """
- def __init__(self, vocab_size, hidden_size, initializer_range=None, **kwargs):
- super().__init__(**kwargs)
- self.vocab_size = vocab_size
- self.hidden_size = hidden_size
- self.initializer_range = hidden_size ** -0.5 if initializer_range is None else initializer_range
- def build(self, input_shape):
- """Build shared token embedding layer
- Shared weights logic adapted from
- https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
- """
- self.weight = self.add_weight(
- "weight", shape=[self.vocab_size, self.hidden_size], initializer=get_initializer(self.initializer_range)
- )
- super().build(input_shape)
- def call(self, inputs, mode="embedding"):
- """Get token embeddings of inputs.
- Args:
- inputs: list of three int64 tensors with shape [batch_size, length]: (input_ids, position_ids, token_type_ids)
- mode: string, a valid value is one of "embedding" and "linear".
- Returns:
- outputs: (1) If mode == "embedding", output embedding tensor, float32 with
- shape [batch_size, length, embedding_size]; (2) mode == "linear", output
- linear tensor, float32 with shape [batch_size, length, vocab_size].
- Raises:
- ValueError: if mode is not valid.
- Shared weights logic adapted from
- https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
- """
- if mode == "embedding":
- return self._embedding(inputs)
- elif mode == "linear":
- return self._linear(inputs)
- else:
- raise ValueError("mode {} is not valid.".format(mode))
- def _embedding(self, input_ids):
- """Applies embedding based on inputs tensor."""
- return tf.gather(self.weight, input_ids)
- def _linear(self, inputs):
- """Computes logits by running inputs through a linear layer.
- Args:
- inputs: A float32 tensor with shape [..., hidden_size]
- Returns:
- float32 tensor with shape [..., vocab_size].
- """
- first_dims = shape_list(inputs)[:-1]
- x = tf.reshape(inputs, [-1, self.hidden_size])
- logits = tf.matmul(x, self.weight, transpose_b=True)
- return tf.reshape(logits, first_dims + [self.vocab_size])
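- # Usage sketch (illustrative, hypothetical sizes): emb = TFSharedEmbeddings(30522, 768);
- # emb(input_ids) returns [batch, seq, 768] embeddings, while emb(hidden_states, mode="linear")
- # reuses the same weight matrix transposed to produce [batch, seq, 30522] logits -- the
- # standard weight-tying trick between input embeddings and the output projection.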
- class TFSequenceSummary(tf.keras.layers.Layer):
- r""" Compute a single vector summary of a sequence hidden states according to various possibilities:
- Args of the config class:
- summary_type:
- - 'last' => [default] take the last token hidden state (like XLNet)
- - 'first' => take the first token hidden state (like Bert)
- - 'mean' => take the mean of all tokens hidden states
- - 'cls_index' => supply a Tensor of classification token position (GPT/GPT-2)
- - 'attn' => Not implemented now, use multi-head attention
- summary_use_proj: Add a projection after the vector extraction
- summary_proj_to_labels: If True, the projection outputs to config.num_labels classes (otherwise to hidden_size). Default: False.
- summary_activation: 'tanh' => add a tanh activation to the output, other => no activation. Default: no activation.
- summary_first_dropout: Add a dropout before the projection and activation
- summary_last_dropout: Add a dropout after the projection and activation
- """
- def __init__(self, config, initializer_range=0.02, **kwargs):
- super().__init__(**kwargs)
- self.summary_type = config.summary_type if hasattr(config, "summary_use_proj") else "last"
- if self.summary_type == "attn":
- # We should use a standard multi-head attention module with absolute positional embedding for that.
- # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
- # We can probably just use the multi-head attention module of PyTorch >=1.1.0
- raise NotImplementedError
- self.has_summary = hasattr(config, "summary_use_proj") and config.summary_use_proj
- if self.has_summary:
- if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
- num_classes = config.num_labels
- else:
- num_classes = config.hidden_size
- self.summary = tf.keras.layers.Dense(
- num_classes, kernel_initializer=get_initializer(initializer_range), name="summary"
- )
- self.has_activation = hasattr(config, "summary_activation") and config.summary_activation == "tanh"
- if self.has_activation:
- self.activation = tf.keras.activations.tanh
- self.has_first_dropout = hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0
- if self.has_first_dropout:
- self.first_dropout = tf.keras.layers.Dropout(config.summary_first_dropout)
- self.has_last_dropout = hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0
- if self.has_last_dropout:
- self.last_dropout = tf.keras.layers.Dropout(config.summary_last_dropout)
- def call(self, inputs, training=False):
- """ hidden_states: float Tensor in shape [bsz, seq_len, hidden_size], the hidden-states of the last layer.
- cls_index: [optional] position of the classification token if summary_type == 'cls_index',
- shape (bsz,) or more generally (bsz, ...) where ... are optional leading dimensions of hidden_states.
- if summary_type == 'cls_index' and cls_index is None:
- we take the last token of the sequence as classification token
- """
- if not isinstance(inputs, (dict, tuple, list)):
- hidden_states = inputs
- cls_index = None
- elif isinstance(inputs, (tuple, list)):
- hidden_states = inputs[0]
- cls_index = inputs[1] if len(inputs) > 1 else None
- assert len(inputs) <= 2, "Too many inputs."
- else:
- hidden_states = inputs.get("hidden_states")
- cls_index = inputs.get("cls_index", None)
- if self.summary_type == "last":
- output = hidden_states[:, -1]
- elif self.summary_type == "first":
- output = hidden_states[:, 0]
- elif self.summary_type == "mean":
- output = tf.reduce_mean(hidden_states, axis=1)
- elif self.summary_type == "cls_index":
- hidden_shape = shape_list(hidden_states) # e.g. [batch, num choices, seq length, hidden dims]
- if cls_index is None:
- cls_index = tf.fill(
- hidden_shape[:-2], hidden_shape[-2] - 1
- ) # a tensor of shape [batch] or [batch, num choices] filled with the index of the last token (sequence length - 1)
- cls_shape = shape_list(cls_index)
- if len(cls_shape) <= len(hidden_shape) - 2:
- cls_index = cls_index[..., tf.newaxis]
- # shape of cls_index here: (bsz, XX, 1) where XX are optional leading dims of hidden_states
- output = tf.gather(hidden_states, cls_index, batch_dims=len(hidden_shape) - 2)
- output = tf.squeeze(
- output, axis=len(hidden_shape) - 2
- ) # shape of output: (batch, num choices, hidden_size)
- elif self.summary_type == "attn":
- raise NotImplementedError
- if self.has_first_dropout:
- output = self.first_dropout(output, training=training)
- if self.has_summary:
- output = self.summary(output)
- if self.has_activation:
- output = self.activation(output)
- if self.has_last_dropout:
- output = self.last_dropout(output, training=training)
- return output
- def shape_list(x):
- """Deal with dynamic shape in tensorflow cleanly."""
- static = x.shape.as_list()
- dynamic = tf.shape(x)
- return [dynamic[i] if s is None else s for i, s in enumerate(static)]
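- # Example (illustrative): for a tensor with static shape [None, 128], shape_list returns
- # [tf.shape(x)[0], 128] -- a dynamic value where the dimension is unknown and plain Python
- # ints elsewhere, so the result works both in graph ops and in ordinary assertions.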
- def get_initializer(initializer_range=0.02):
- """Creates a `tf.initializers.truncated_normal` with the given range.
- Args:
- initializer_range: float, initializer range for stddev.
- Returns:
- TruncatedNormal initializer with stddev = `initializer_range`.
- """
- return tf.keras.initializers.TruncatedNormal(stddev=initializer_range)
- TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
- "bert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-tf_model.h5",
- "bert-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-tf_model.h5",
- "bert-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-tf_model.h5",
- "bert-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-tf_model.h5",
- "bert-base-multilingual-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-tf_model.h5",
- "bert-base-multilingual-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-tf_model.h5",
- "bert-base-chinese": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-tf_model.h5",
- "bert-base-german-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-tf_model.h5",
- "bert-large-uncased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-tf_model.h5",
- "bert-large-cased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-tf_model.h5",
- "bert-large-uncased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-tf_model.h5",
- "bert-large-cased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-tf_model.h5",
- "bert-base-cased-finetuned-mrpc": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-tf_model.h5",
- "bert-base-japanese": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-tf_model.h5",
- "bert-base-japanese-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-whole-word-masking-tf_model.h5",
- "bert-base-japanese-char": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-tf_model.h5",
- "bert-base-japanese-char-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking-tf_model.h5",
- "bert-base-finnish-cased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/tf_model.h5",
- "bert-base-finnish-uncased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/tf_model.h5",
- "bert-base-dutch-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/wietsedv/bert-base-dutch-cased/tf_model.h5",
- }
- def gelu(x):
- """ Gaussian Error Linear Unit.
- Original Implementation of the gelu activation function in Google Bert repo when initially created.
- For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
- 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
- Also see https://arxiv.org/abs/1606.08415
- """
- cdf = 0.5 * (1.0 + tf.math.erf(x / tf.math.sqrt(2.0)))
- return x * cdf
- def gelu_new(x):
- """Gaussian Error Linear Unit.
- This is a smoother version of the ReLU.
- Original paper: https://arxiv.org/abs/1606.08415
- Args:
- x: float Tensor to perform activation.
- Returns:
- `x` with the GELU activation applied.
- """
- cdf = 0.5 * (1.0 + tf.tanh((np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3)))))
- return x * cdf
- def swish(x):
- return x * tf.sigmoid(x)
- ACT2FN = {
- "gelu": tf.keras.layers.Activation(gelu),
- "relu": tf.keras.activations.relu,
- "swish": tf.keras.layers.Activation(swish),
- "gelu_new": tf.keras.layers.Activation(gelu_new),
- }
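- # Usage sketch (illustrative): a config with hidden_act="gelu" makes TFBertIntermediate below
- # pick ACT2FN["gelu"], i.e. the erf-based gelu above; "gelu_new" selects the tanh approximation.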
- class TFBertEmbeddings(tf.keras.layers.Layer):
- """Construct the embeddings from word, position and token_type embeddings.
- """
- def __init__(self, config, **kwargs):
- super().__init__(**kwargs)
- self.vocab_size = config.vocab_size
- self.hidden_size = config.hidden_size
- self.initializer_range = config.initializer_range
- self.position_embeddings = tf.keras.layers.Embedding(
- config.max_position_embeddings,
- config.hidden_size,
- embeddings_initializer=get_initializer(self.initializer_range),
- name="position_embeddings",
- )
- self.token_type_embeddings = tf.keras.layers.Embedding(
- config.type_vocab_size,
- config.hidden_size,
- embeddings_initializer=get_initializer(self.initializer_range),
- name="token_type_embeddings",
- )
- # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
- # any TensorFlow checkpoint file
- self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
- self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
- def build(self, input_shape):
- """Build shared word embedding layer """
- with tf.name_scope("word_embeddings"):
- # Create and initialize weights. The random normal initializer was chosen
- # arbitrarily, and works well.
- self.word_embeddings = self.add_weight(
- "weight",
- shape=[self.vocab_size, self.hidden_size],
- initializer=get_initializer(self.initializer_range),
- )
- super().build(input_shape)
- def call(self, inputs, mode="embedding", training=False):
- """Get token embeddings of inputs.
- Args:
- inputs: list of three int64 tensors with shape [batch_size, length]: (input_ids, position_ids, token_type_ids)
- mode: string, a valid value is one of "embedding" and "linear".
- Returns:
- outputs: (1) If mode == "embedding", output embedding tensor, float32 with
- shape [batch_size, length, embedding_size]; (2) mode == "linear", output
- linear tensor, float32 with shape [batch_size, length, vocab_size].
- Raises:
- ValueError: if mode is not valid.
- Shared weights logic adapted from
- https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
- """
- if mode == "embedding":
- return self._embedding(inputs, training=training)
- elif mode == "linear":
- return self._linear(inputs)
- else:
- raise ValueError("mode {} is not valid.".format(mode))
- def _embedding(self, inputs, training=False):
- """Applies embedding based on inputs tensor."""
- input_ids, position_ids, token_type_ids, inputs_embeds = inputs
- if input_ids is not None:
- input_shape = shape_list(input_ids)
- else:
- input_shape = shape_list(inputs_embeds)[:-1]
- seq_length = input_shape[1]
- if position_ids is None:
- position_ids = tf.range(seq_length, dtype=tf.int32)[tf.newaxis, :]
- if token_type_ids is None:
- token_type_ids = tf.fill(input_shape, 0)
- if inputs_embeds is None:
- inputs_embeds = tf.gather(self.word_embeddings, input_ids)
- position_embeddings = self.position_embeddings(position_ids)
- token_type_embeddings = self.token_type_embeddings(token_type_ids)
- embeddings = inputs_embeds + position_embeddings + token_type_embeddings
- embeddings = self.LayerNorm(embeddings)
- embeddings = self.dropout(embeddings, training=training)
- return embeddings
- def _linear(self, inputs):
- """Computes logits by running inputs through a linear layer.
- Args:
- inputs: A float32 tensor with shape [batch_size, length, hidden_size]
- Returns:
- float32 tensor with shape [batch_size, length, vocab_size].
- """
- batch_size = shape_list(inputs)[0]
- length = shape_list(inputs)[1]
- x = tf.reshape(inputs, [-1, self.hidden_size])
- logits = tf.matmul(x, self.word_embeddings, transpose_b=True)
- return tf.reshape(logits, [batch_size, length, self.vocab_size])
- class TFBertSelfAttention(tf.keras.layers.Layer):
- def __init__(self, config, **kwargs):
- super().__init__(**kwargs)
- if config.hidden_size % config.num_attention_heads != 0:
- raise ValueError(
- "The hidden size (%d) is not a multiple of the number of attention "
- "heads (%d)" % (config.hidden_size, config.num_attention_heads)
- )
- self.output_attentions = config.output_attentions
- self.num_attention_heads = config.num_attention_heads
- assert config.hidden_size % config.num_attention_heads == 0
- self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
- self.all_head_size = self.num_attention_heads * self.attention_head_size
- self.amp = config.amp # mixed-precision flag (specific to this variant's config); selects the cast dtype below
- self.query = tf.keras.layers.Dense(
- self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
- )
- self.key = tf.keras.layers.Dense(
- self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key"
- )
- self.value = tf.keras.layers.Dense(
- self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
- )
- self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob)
- def transpose_for_scores(self, x, batch_size):
- x = tf.reshape(x, (batch_size, -1, self.num_attention_heads, self.attention_head_size))
- return tf.transpose(x, perm=[0, 2, 1, 3])
- def call(self, inputs, training=False):
- hidden_states, attention_mask, head_mask = inputs
- batch_size = shape_list(hidden_states)[0]
- mixed_query_layer = self.query(hidden_states)
- mixed_key_layer = self.key(hidden_states)
- mixed_value_layer = self.value(hidden_states)
- query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
- key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
- value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)
- # Take the dot product between "query" and "key" to get the raw attention scores.
- attention_scores = tf.matmul(
- query_layer, key_layer, transpose_b=True
- ) # (batch size, num_heads, seq_len_q, seq_len_k)
- dk = tf.cast(shape_list(key_layer)[-1], tf.float32)
- attention_scores = attention_scores / tf.cast(tf.math.sqrt(dk), tf.float16 if self.amp else tf.float32)
- if attention_mask is not None:
- # Apply the attention mask (precomputed for all layers in the TFBertModel call() function)
- attention_scores = attention_scores + attention_mask
- # Normalize the attention scores to probabilities.
- attention_probs = tf.nn.softmax(attention_scores, axis=-1)
- # This is actually dropping out entire tokens to attend to, which might
- # seem a bit unusual, but is taken from the original Transformer paper.
- attention_probs = self.dropout(attention_probs, training=training)
- # Mask heads if we want to
- if head_mask is not None:
- attention_probs = attention_probs * head_mask
- context_layer = tf.matmul(attention_probs, value_layer)
- context_layer = tf.transpose(context_layer, perm=[0, 2, 1, 3])
- context_layer = tf.reshape(
- context_layer, (batch_size, -1, self.all_head_size)
- ) # (batch_size, seq_len_q, all_head_size)
- outputs = (context_layer, attention_probs) if self.output_attentions else (context_layer,)
- return outputs
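- # Shape walk-through (illustrative): hidden_states [batch, seq, hidden] is projected to q/k/v
- # and reshaped to [batch, num_heads, seq, head_size]; q @ k^T gives scores of shape
- # [batch, num_heads, seq, seq], scaled by 1/sqrt(head_size), masked, softmaxed, and applied
- # to v before the heads are merged back to [batch, seq, hidden].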
- class TFBertSelfOutput(tf.keras.layers.Layer):
- def __init__(self, config, **kwargs):
- super().__init__(**kwargs)
- self.dense = tf.keras.layers.Dense(
- config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
- )
- self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
- self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
- def call(self, inputs, training=False):
- hidden_states, input_tensor = inputs
- hidden_states = self.dense(hidden_states)
- hidden_states = self.dropout(hidden_states, training=training)
- hidden_states = self.LayerNorm(hidden_states + input_tensor)
- return hidden_states
- class TFBertAttention(tf.keras.layers.Layer):
- def __init__(self, config, **kwargs):
- super().__init__(**kwargs)
- self.self_attention = TFBertSelfAttention(config, name="self")
- self.dense_output = TFBertSelfOutput(config, name="output")
- def prune_heads(self, heads):
- raise NotImplementedError
- def call(self, inputs, training=False):
- input_tensor, attention_mask, head_mask = inputs
- self_outputs = self.self_attention([input_tensor, attention_mask, head_mask], training=training)
- attention_output = self.dense_output([self_outputs[0], input_tensor], training=training)
- outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
- return outputs
- class TFBertIntermediate(tf.keras.layers.Layer):
- def __init__(self, config, **kwargs):
- super().__init__(**kwargs)
- self.dense = tf.keras.layers.Dense(
- config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
- )
- if isinstance(config.hidden_act, str):
- self.intermediate_act_fn = ACT2FN[config.hidden_act]
- else:
- self.intermediate_act_fn = config.hidden_act
- def call(self, hidden_states):
- hidden_states = self.dense(hidden_states)
- hidden_states = self.intermediate_act_fn(hidden_states)
- return hidden_states
- class TFBertOutput(tf.keras.layers.Layer):
- def __init__(self, config, **kwargs):
- super().__init__(**kwargs)
- self.dense = tf.keras.layers.Dense(
- config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
- )
- self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
- self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
- def call(self, inputs, training=False):
- hidden_states, input_tensor = inputs
- hidden_states = self.dense(hidden_states)
- hidden_states = self.dropout(hidden_states, training=training)
- hidden_states = self.LayerNorm(hidden_states + input_tensor)
- return hidden_states
- class TFBertLayer(tf.keras.layers.Layer):
- def __init__(self, config, **kwargs):
- super().__init__(**kwargs)
- self.attention = TFBertAttention(config, name="attention")
- self.intermediate = TFBertIntermediate(config, name="intermediate")
- self.bert_output = TFBertOutput(config, name="output")
- def call(self, inputs, training=False):
- hidden_states, attention_mask, head_mask = inputs
- attention_outputs = self.attention([hidden_states, attention_mask, head_mask], training=training)
- attention_output = attention_outputs[0]
- intermediate_output = self.intermediate(attention_output)
- layer_output = self.bert_output([intermediate_output, attention_output], training=training)
- outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them
- return outputs
- class TFBertEncoder(tf.keras.layers.Layer):
- def __init__(self, config, **kwargs):
- super().__init__(**kwargs)
- self.output_attentions = config.output_attentions
- self.output_hidden_states = config.output_hidden_states
- self.layer = [TFBertLayer(config, name="layer_._{}".format(i)) for i in range(config.num_hidden_layers)]
- def call(self, inputs, training=False):
- hidden_states, attention_mask, head_mask = inputs
- all_hidden_states = ()
- all_attentions = ()
- for i, layer_module in enumerate(self.layer):
- if self.output_hidden_states:
- all_hidden_states = all_hidden_states + (hidden_states,)
- layer_outputs = layer_module([hidden_states, attention_mask, head_mask[i]], training=training)
- hidden_states = layer_outputs[0]
- if self.output_attentions:
- all_attentions = all_attentions + (layer_outputs[1],)
- # Add last layer
- if self.output_hidden_states:
- all_hidden_states = all_hidden_states + (hidden_states,)
- outputs = (hidden_states,)
- if self.output_hidden_states:
- outputs = outputs + (all_hidden_states,)
- if self.output_attentions:
- outputs = outputs + (all_attentions,)
- return outputs # outputs, (hidden states), (attentions)
- class TFBertPooler(tf.keras.layers.Layer):
- def __init__(self, config, **kwargs):
- super().__init__(**kwargs)
- self.dense = tf.keras.layers.Dense(
- config.hidden_size,
- kernel_initializer=get_initializer(config.initializer_range),
- activation="tanh",
- name="dense",
- )
- def call(self, hidden_states):
- # We "pool" the model by simply taking the hidden state corresponding
- # to the first token.
- first_token_tensor = hidden_states[:, 0]
- pooled_output = self.dense(first_token_tensor)
- return pooled_output
- class TFBertPredictionHeadTransform(tf.keras.layers.Layer):
- def __init__(self, config, **kwargs):
- super().__init__(**kwargs)
- self.dense = tf.keras.layers.Dense(
- config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
- )
- if isinstance(config.hidden_act, str):
- self.transform_act_fn = ACT2FN[config.hidden_act]
- else:
- self.transform_act_fn = config.hidden_act
- self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
- def call(self, hidden_states):
- hidden_states = self.dense(hidden_states)
- hidden_states = self.transform_act_fn(hidden_states)
- hidden_states = self.LayerNorm(hidden_states)
- return hidden_states
- class TFBertLMPredictionHead(tf.keras.layers.Layer):
- def __init__(self, config, input_embeddings, **kwargs):
- super().__init__(**kwargs)
- self.vocab_size = config.vocab_size
- self.transform = TFBertPredictionHeadTransform(config, name="transform")
- # The output weights are the same as the input embeddings, but there is
- # an output-only bias for each token.
- self.input_embeddings = input_embeddings
- def build(self, input_shape):
- self.bias = self.add_weight(shape=(self.vocab_size,), initializer="zeros", trainable=True, name="bias")
- super().build(input_shape)
- def call(self, hidden_states):
- hidden_states = self.transform(hidden_states)
- hidden_states = self.input_embeddings(hidden_states, mode="linear")
- hidden_states = hidden_states + self.bias
- return hidden_states
- class TFBertMLMHead(tf.keras.layers.Layer):
- def __init__(self, config, input_embeddings, **kwargs):
- super().__init__(**kwargs)
- self.predictions = TFBertLMPredictionHead(config, input_embeddings, name="predictions")
- def call(self, sequence_output):
- prediction_scores = self.predictions(sequence_output)
- return prediction_scores
- class TFBertNSPHead(tf.keras.layers.Layer):
- def __init__(self, config, **kwargs):
- super().__init__(**kwargs)
- self.seq_relationship = tf.keras.layers.Dense(
- 2, kernel_initializer=get_initializer(config.initializer_range), name="seq_relationship"
- )
- def call(self, pooled_output):
- seq_relationship_score = self.seq_relationship(pooled_output)
- return seq_relationship_score
- @keras_serializable
- class TFBertMainLayer(tf.keras.layers.Layer):
- config_class = BertConfig
- def __init__(self, config, **kwargs):
- super().__init__(**kwargs)
- self.num_hidden_layers = config.num_hidden_layers
- self.embeddings = TFBertEmbeddings(config, name="embeddings")
- self.encoder = TFBertEncoder(config, name="encoder")
- self.pooler = TFBertPooler(config, name="pooler")
- def get_input_embeddings(self):
- return self.embeddings
- def _resize_token_embeddings(self, new_num_tokens):
- raise NotImplementedError
- def _prune_heads(self, heads_to_prune):
- """ Prunes heads of the model.
- heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
- See base class PreTrainedModel
- """
- raise NotImplementedError
- def call(
- self,
- inputs,
- attention_mask=None,
- token_type_ids=None,
- position_ids=None,
- head_mask=None,
- inputs_embeds=None,
- training=False,
- ):
- if isinstance(inputs, (tuple, list)):
- input_ids = inputs[0]
- attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
- token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
- position_ids = inputs[3] if len(inputs) > 3 else position_ids
- head_mask = inputs[4] if len(inputs) > 4 else head_mask
- inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
- assert len(inputs) <= 6, "Too many inputs."
- elif isinstance(inputs, (dict, BatchEncoding)):
- input_ids = inputs.get("input_ids")
- attention_mask = inputs.get("attention_mask", attention_mask)
- token_type_ids = inputs.get("token_type_ids", token_type_ids)
- position_ids = inputs.get("position_ids", position_ids)
- head_mask = inputs.get("head_mask", head_mask)
- inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
- assert len(inputs) <= 6, "Too many inputs."
- else:
- input_ids = inputs
- if input_ids is not None and inputs_embeds is not None:
- raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
- elif input_ids is not None:
- input_shape = shape_list(input_ids)
- elif inputs_embeds is not None:
- input_shape = shape_list(inputs_embeds)[:-1]
- else:
- raise ValueError("You have to specify either input_ids or inputs_embeds")
- if attention_mask is None:
- attention_mask = tf.fill(input_shape, 1)
- if token_type_ids is None:
- token_type_ids = tf.fill(input_shape, 0)
- # We create a 3D attention mask from a 2D tensor mask.
- # Sizes are [batch_size, 1, 1, to_seq_length]
- # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
- # this attention mask is more simple than the triangular masking of causal attention
- # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
- extended_attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]
- # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
- # masked positions, this operation will create a tensor which is 0.0 for
- # positions we want to attend and -10000.0 for masked positions.
- # Since we are adding it to the raw scores before the softmax, this is
- # effectively the same as removing these entirely.
- extended_attention_mask = tf.cast(extended_attention_mask, tf.float32)
- extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
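- # Example (illustrative): an attention_mask row [1, 1, 0] becomes [0., 0., -10000.] here, so
- # the softmax assigns (near-)zero probability to the padded position.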
- # Prepare head mask if needed
- # 1.0 in head_mask indicate we keep the head
- # attention_probs has shape bsz x n_heads x N x N
- # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
- # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
- if head_mask is not None:
- raise NotImplementedError
- else:
- head_mask = [None] * self.num_hidden_layers
- # head_mask = tf.constant([0] * self.num_hidden_layers)
- embedding_output = self.embeddings([input_ids, position_ids, token_type_ids, inputs_embeds], training=training)
- encoder_outputs = self.encoder([embedding_output, extended_attention_mask, head_mask], training=training)
- sequence_output = encoder_outputs[0]
- pooled_output = self.pooler(sequence_output)
- outputs = (sequence_output, pooled_output,) + encoder_outputs[
- 1:
- ] # add hidden_states and attentions if they are here
- return outputs # sequence_output, pooled_output, (hidden_states), (attentions)
- class TFBertPreTrainedModel(TFPreTrainedModel):
- """ An abstract class to handle weights initialization and
- a simple interface for downloading and loading pretrained models.
- """
- config_class = BertConfig
- pretrained_model_archive_map = TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP
- base_model_prefix = "bert"
- BERT_START_DOCSTRING = r"""
- This model is a `tf.keras.Model <https://www.tensorflow.org/api_docs/python/tf/keras/Model>`__ sub-class.
- Use it as a regular TF 2.0 Keras Model and
- refer to the TF 2.0 documentation for all matters related to general usage and behavior.
- .. note::
- TF 2.0 models accept two formats as inputs:
- - having all inputs as keyword arguments (like PyTorch models), or
- - having all inputs as a list, tuple or dict in the first positional arguments.
- This second option is useful when using the :obj:`tf.keras.Model.fit()` method, which currently requires having
- all the tensors in the first argument of the model call function: :obj:`model(inputs)`.
- If you choose this second option, there are three possibilities you can use to gather all the input Tensors
- in the first positional argument:
- - a single Tensor with input_ids only and nothing else: :obj:`model(input_ids)`
- - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
- :obj:`model([input_ids, attention_mask])` or :obj:`model([input_ids, attention_mask, token_type_ids])`
- - a dictionary with one or several input Tensors associated to the input names given in the docstring:
- :obj:`model({'input_ids': input_ids, 'token_type_ids': token_type_ids})`
- Parameters:
- config (:class:`~transformers.BertConfig`): Model configuration class with all the parameters of the model.
- Initializing with a config file does not load the weights associated with the model, only the configuration.
- Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
- """
- BERT_INPUTS_DOCSTRING = r"""
- Args:
- input_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`):
- Indices of input sequence tokens in the vocabulary.
- Indices can be obtained using :class:`transformers.BertTokenizer`.
- See :func:`transformers.PreTrainedTokenizer.encode` and
- :func:`transformers.PreTrainedTokenizer.encode_plus` for details.
- `What are input IDs? <../glossary.html#input-ids>`__
- attention_mask (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
- Mask to avoid performing attention on padding token indices.
- Mask values selected in ``[0, 1]``:
- ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
- `What are attention masks? <../glossary.html#attention-mask>`__
- token_type_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
- Segment token indices to indicate first and second portions of the inputs.
- Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
- corresponds to a `sentence B` token
- `What are token type IDs? <../glossary.html#token-type-ids>`__
- position_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
- Indices of positions of each input sequence tokens in the position embeddings.
- Selected in the range ``[0, config.max_position_embeddings - 1]``.
- `What are position IDs? <../glossary.html#position-ids>`__
- head_mask (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`):
- Mask to nullify selected heads of the self-attention modules.
- Mask values selected in ``[0, 1]``:
- :obj:`1` indicates the head is **not masked**, :obj:`0` indicates the head is **masked**.
- inputs_embeds (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, embedding_dim)`, `optional`, defaults to :obj:`None`):
- Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
- This is useful if you want more control over how to convert `input_ids` indices into associated vectors
- than the model's internal embedding lookup matrix.
- training (:obj:`boolean`, `optional`, defaults to :obj:`False`):
- Whether to activate dropout modules (if set to :obj:`True`) during training or to de-activate them
- (if set to :obj:`False`) for evaluation.
- """
- @add_start_docstrings(
- "The bare Bert Model transformer outputing raw hidden-states without any specific head on top.",
- BERT_START_DOCSTRING,
- )
- class TFBertModel(TFBertPreTrainedModel):
- def __init__(self, config, *inputs, **kwargs):
- super().__init__(config, *inputs, **kwargs)
- self.bert = TFBertMainLayer(config, name="bert")
- @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
- def call(self, inputs, **kwargs):
- r"""
- Returns:
- :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
- last_hidden_state (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
- Sequence of hidden-states at the output of the last layer of the model.
- pooler_output (:obj:`tf.Tensor` of shape :obj:`(batch_size, hidden_size)`):
- Last layer hidden-state of the first token of the sequence (classification token)
- further processed by a Linear layer and a Tanh activation function. The Linear
- layer weights are trained from the next sentence prediction (classification)
- objective during Bert pretraining. This output is usually *not* a good summary
- of the semantic content of the input; you're often better off averaging or pooling
- the sequence of hidden-states for the whole input sequence.
- hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
- tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
- of shape :obj:`(batch_size, sequence_length, hidden_size)`.
- Hidden-states of the model at the output of each layer plus the initial embedding outputs.
- attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
- tuple of :obj:`tf.Tensor` (one for each layer) of shape
- :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
- Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.
- Examples::
- import tensorflow as tf
- from transformers import BertTokenizer, TFBertModel
- tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
- model = TFBertModel.from_pretrained('bert-base-uncased')
- input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :] # Batch size 1
- outputs = model(input_ids)
- last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
- """
- outputs = self.bert(inputs, **kwargs)
- return outputs
- @add_start_docstrings(
- """Bert Model with two heads on top as done during the pre-training:
- a `masked language modeling` head and a `next sentence prediction (classification)` head. """,
- BERT_START_DOCSTRING,
- )
- class TFBertForPreTraining(TFBertPreTrainedModel):
- def __init__(self, config, *inputs, **kwargs):
- super().__init__(config, *inputs, **kwargs)
- self.bert = TFBertMainLayer(config, name="bert")
- self.nsp = TFBertNSPHead(config, name="nsp___cls")
- self.mlm = TFBertMLMHead(config, self.bert.embeddings, name="mlm___cls")
- def get_output_embeddings(self):
- return self.bert.embeddings
- @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
- def call(self, inputs, **kwargs):
- r"""
- Return:
- :obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
- prediction_scores (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
- Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
- seq_relationship_scores (:obj:`tf.Tensor` of shape :obj:`(batch_size, 2)`):
- Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation before SoftMax).
- hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
- tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
- of shape :obj:`(batch_size, sequence_length, hidden_size)`.
- Hidden-states of the model at the output of each layer plus the initial embedding outputs.
- attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
- tuple of :obj:`tf.Tensor` (one for each layer) of shape
- :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
- Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.
- Examples::
- import tensorflow as tf
- from transformers import BertTokenizer, TFBertForPreTraining
- tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
- model = TFBertForPreTraining.from_pretrained('bert-base-uncased')
- input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :] # Batch size 1
- outputs = model(input_ids)
- prediction_scores, seq_relationship_scores = outputs[:2]
- """
- outputs = self.bert(inputs, **kwargs)
- sequence_output, pooled_output = outputs[:2]
- prediction_scores = self.mlm(sequence_output, training=kwargs.get("training", False))
- seq_relationship_score = self.nsp(pooled_output)
- outputs = (prediction_scores, seq_relationship_score,) + outputs[
- 2:
- ] # add hidden states and attention if they are here
- return outputs # prediction_scores, seq_relationship_score, (hidden_states), (attentions)
- @add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING)
- class TFBertForMaskedLM(TFBertPreTrainedModel):
- def __init__(self, config, *inputs, **kwargs):
- super().__init__(config, *inputs, **kwargs)
- self.bert = TFBertMainLayer(config, name="bert")
- self.mlm = TFBertMLMHead(config, self.bert.embeddings, name="mlm___cls")
- def get_output_embeddings(self):
- return self.bert.embeddings
- @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
- def call(self, inputs, **kwargs):
- r"""
- Return:
- :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
- prediction_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
- Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
- hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
- tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
- of shape :obj:`(batch_size, sequence_length, hidden_size)`.
- Hidden-states of the model at the output of each layer plus the initial embedding outputs.
- attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
- tuple of :obj:`tf.Tensor` (one for each layer) of shape
- :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
- Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
- Examples::
- import tensorflow as tf
- from transformers import BertTokenizer, TFBertForMaskedLM
- tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
- model = TFBertForMaskedLM.from_pretrained('bert-base-uncased')
- input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :] # Batch size 1
- outputs = model(input_ids)
- prediction_scores = outputs[0]
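
        # Illustrative follow-up (a sketch assuming eager execution; not part of the original
        # example): take the highest-scoring vocabulary id at each position and map it to tokens.
        predicted_ids = tf.argmax(prediction_scores, axis=-1)  # shape (batch_size, sequence_length)
        predicted_tokens = tokenizer.convert_ids_to_tokens(predicted_ids[0].numpy().tolist())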
- """
- outputs = self.bert(inputs, **kwargs)
- sequence_output = outputs[0]
- prediction_scores = self.mlm(sequence_output, training=kwargs.get("training", False))
- outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention if they are here
- return outputs # prediction_scores, (hidden_states), (attentions)
- @add_start_docstrings(
- """Bert Model with a `next sentence prediction (classification)` head on top. """, BERT_START_DOCSTRING,
- )
- class TFBertForNextSentencePrediction(TFBertPreTrainedModel):
- def __init__(self, config, *inputs, **kwargs):
- super().__init__(config, *inputs, **kwargs)
- self.bert = TFBertMainLayer(config, name="bert")
- self.nsp = TFBertNSPHead(config, name="nsp___cls")
- @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
- def call(self, inputs, **kwargs):
- r"""
- Return:
- :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
- seq_relationship_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, 2)`)
- Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation before SoftMax).
- hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
- tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
- of shape :obj:`(batch_size, sequence_length, hidden_size)`.
- Hidden-states of the model at the output of each layer plus the initial embedding outputs.
- attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
- tuple of :obj:`tf.Tensor` (one for each layer) of shape
- :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
- Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
- Examples::
- import tensorflow as tf
- from transformers import BertTokenizer, TFBertForNextSentencePrediction
- tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
- model = TFBertForNextSentencePrediction.from_pretrained('bert-base-uncased')
- input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :] # Batch size 1
- outputs = model(input_ids)
- seq_relationship_scores = outputs[0]
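
        # Illustrative follow-up (a sketch, not part of the original example): encode an actual
        # sentence pair and softmax the scores; by BERT's pre-training convention, index 0 is
        # usually the "is the next sentence" label.
        pair_ids = tf.constant(tokenizer.encode("The sky is blue.", "Cows chew grass.", add_special_tokens=True))[None, :]
        pair_scores = model(pair_ids)[0]
        is_next_prob = tf.nn.softmax(pair_scores, axis=-1)[0, 0]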
- """
- outputs = self.bert(inputs, **kwargs)
- pooled_output = outputs[1]
- seq_relationship_score = self.nsp(pooled_output)
- outputs = (seq_relationship_score,) + outputs[2:] # add hidden states and attention if they are here
- return outputs # seq_relationship_score, (hidden_states), (attentions)
- @add_start_docstrings(
- """Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of
- the pooled output) e.g. for GLUE tasks. """,
- BERT_START_DOCSTRING,
- )
- class TFBertForSequenceClassification(TFBertPreTrainedModel):
- def __init__(self, config, *inputs, **kwargs):
- super().__init__(config, *inputs, **kwargs)
- self.num_labels = config.num_labels
- self.bert = TFBertMainLayer(config, name="bert")
- self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
- self.classifier = tf.keras.layers.Dense(
- config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
- )
- @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
- def call(self, inputs, **kwargs):
- r"""
- Return:
- :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
- logits (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, config.num_labels)`):
- Classification (or regression if config.num_labels==1) scores (before SoftMax).
- hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
- tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
- of shape :obj:`(batch_size, sequence_length, hidden_size)`.
- Hidden-states of the model at the output of each layer plus the initial embedding outputs.
- attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
- tuple of :obj:`tf.Tensor` (one for each layer) of shape
- :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
- Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
- Examples::
- import tensorflow as tf
- from transformers import BertTokenizer, TFBertForSequenceClassification
- tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
- model = TFBertForSequenceClassification.from_pretrained('bert-base-uncased')
- input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :] # Batch size 1
- outputs = model(input_ids)
- logits = outputs[0]
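
        # Illustrative follow-up (a sketch, not part of the original example): convert the raw
        # logits to class probabilities and pick the most likely label id.
        probs = tf.nn.softmax(logits, axis=-1)       # shape (batch_size, num_labels)
        predicted_label = tf.argmax(probs, axis=-1)  # shape (batch_size,)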
- """
- outputs = self.bert(inputs, **kwargs)
- pooled_output = outputs[1]
- pooled_output = self.dropout(pooled_output, training=kwargs.get("training", False))
- logits = self.classifier(pooled_output)
- outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
- return outputs # logits, (hidden_states), (attentions)
- @add_start_docstrings(
- """Bert Model with a multiple choice classification head on top (a linear layer on top of
- the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
- BERT_START_DOCSTRING,
- )
- class TFBertForMultipleChoice(TFBertPreTrainedModel):
- def __init__(self, config, *inputs, **kwargs):
- super().__init__(config, *inputs, **kwargs)
- self.bert = TFBertMainLayer(config, name="bert")
- self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
- self.classifier = tf.keras.layers.Dense(
- 1, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
- )
- @property
- def dummy_inputs(self):
- """ Dummy inputs to build the network.
- Returns:
- tf.Tensor with dummy inputs
- """
- return {"input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS)}
- @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
- def call(
- self,
- inputs,
- attention_mask=None,
- token_type_ids=None,
- position_ids=None,
- head_mask=None,
- inputs_embeds=None,
- training=False,
- ):
- r"""
- Return:
- :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
- classification_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, num_choices)`:
- `num_choices` is the size of the second dimension of the input tensors. (see `input_ids` above).
- Classification scores (before SoftMax).
- hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
- tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
- of shape :obj:`(batch_size, sequence_length, hidden_size)`.
- Hidden-states of the model at the output of each layer plus the initial embedding outputs.
- attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
- tuple of :obj:`tf.Tensor` (one for each layer) of shape
- :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
- Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
- Examples::
- import tensorflow as tf
- from transformers import BertTokenizer, TFBertForMultipleChoice
- tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
- model = TFBertForMultipleChoice.from_pretrained('bert-base-uncased')
- choices = ["Hello, my dog is cute", "Hello, my cat is amazing"]
- input_ids = tf.constant([tokenizer.encode(s) for s in choices])[None, :] # Batch size 1, 2 choices
- outputs = model(input_ids)
- classification_scores = outputs[0]
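
        # Illustrative follow-up (a sketch, not part of the original example): one score is
        # produced per choice, so the predicted answer is the argmax over the choice axis.
        predicted_choice = tf.argmax(classification_scores, axis=-1)  # shape (batch_size,)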
- """
- if isinstance(inputs, (tuple, list)):
- input_ids = inputs[0]
- attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
- token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
- position_ids = inputs[3] if len(inputs) > 3 else position_ids
- head_mask = inputs[4] if len(inputs) > 4 else head_mask
- inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
- assert len(inputs) <= 6, "Too many inputs."
- elif isinstance(inputs, dict):
- input_ids = inputs.get("input_ids")
- attention_mask = inputs.get("attention_mask", attention_mask)
- token_type_ids = inputs.get("token_type_ids", token_type_ids)
- position_ids = inputs.get("position_ids", position_ids)
- head_mask = inputs.get("head_mask", head_mask)
- inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
- assert len(inputs) <= 6, "Too many inputs."
- else:
- input_ids = inputs
- if input_ids is not None:
- num_choices = shape_list(input_ids)[1]
- seq_length = shape_list(input_ids)[2]
- else:
- num_choices = shape_list(inputs_embeds)[1]
- seq_length = shape_list(inputs_embeds)[2]
- flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
- flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
- flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
- flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
- flat_inputs = [
- flat_input_ids,
- flat_attention_mask,
- flat_token_type_ids,
- flat_position_ids,
- head_mask,
- inputs_embeds,
- ]
- outputs = self.bert(flat_inputs, training=training)
- pooled_output = outputs[1]
- pooled_output = self.dropout(pooled_output, training=training)
- logits = self.classifier(pooled_output)
- reshaped_logits = tf.reshape(logits, (-1, num_choices))
- outputs = (reshaped_logits,) + outputs[2:] # add hidden states and attention if they are here
- return outputs # reshaped_logits, (hidden_states), (attentions)
- @add_start_docstrings(
- """Bert Model with a token classification head on top (a linear layer on top of
- the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
- BERT_START_DOCSTRING,
- )
- class TFBertForTokenClassification(TFBertPreTrainedModel):
- def __init__(self, config, *inputs, **kwargs):
- super().__init__(config, *inputs, **kwargs)
- self.num_labels = config.num_labels
- self.bert = TFBertMainLayer(config, name="bert")
- self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
- self.classifier = tf.keras.layers.Dense(
- config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
- )
- @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
- def call(self, inputs, **kwargs):
- r"""
- Return:
- :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
- scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`):
- Classification scores (before SoftMax).
- hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
- tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
- of shape :obj:`(batch_size, sequence_length, hidden_size)`.
- Hidden-states of the model at the output of each layer plus the initial embedding outputs.
- attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
- tuple of :obj:`tf.Tensor` (one for each layer) of shape
- :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
- Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
- Examples::
- import tensorflow as tf
- from transformers import BertTokenizer, TFBertForTokenClassification
- tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
- model = TFBertForTokenClassification.from_pretrained('bert-base-uncased')
- input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :] # Batch size 1
- outputs = model(input_ids)
- scores = outputs[0]
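
        # Illustrative follow-up (a sketch, not part of the original example): per-token label
        # predictions are the argmax over the label dimension at each position.
        predicted_label_ids = tf.argmax(scores, axis=-1)  # shape (batch_size, sequence_length)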
- """
- outputs = self.bert(inputs, **kwargs)
- sequence_output = outputs[0]
- sequence_output = self.dropout(sequence_output, training=kwargs.get("training", False))
- logits = self.classifier(sequence_output)
- outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
- return outputs # scores, (hidden_states), (attentions)
- @add_start_docstrings(
- """Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
- the hidden-states output to compute `span start logits` and `span end logits`). """,
- BERT_START_DOCSTRING,
- )
- class TFBertForQuestionAnswering(TFBertPreTrainedModel):
- def __init__(self, config, *inputs, **kwargs):
- super().__init__(config, *inputs, **kwargs)
- self.num_labels = config.num_labels
- self.bert = TFBertMainLayer(config, name="bert")
- self.qa_outputs = tf.keras.layers.Dense(
- config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
- )
- @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
- def call(self, inputs, **kwargs):
- r"""
- Return:
- :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
- start_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length,)`):
- Span-start scores (before SoftMax).
- end_scores (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length,)`):
- Span-end scores (before SoftMax).
- hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
- tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
- of shape :obj:`(batch_size, sequence_length, hidden_size)`.
- Hidden-states of the model at the output of each layer plus the initial embedding outputs.
- attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_attentions=True``):
- tuple of :obj:`tf.Tensor` (one for each layer) of shape
- :obj:`(batch_size, num_heads, sequence_length, sequence_length)`:
- Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
- Examples::
- import tensorflow as tf
- from transformers import BertTokenizer, TFBertForQuestionAnswering
- tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
- model = TFBertForQuestionAnswering.from_pretrained('bert-base-uncased')
- input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True))[None, :] # Batch size 1
- outputs = model(input_ids)
- start_scores, end_scores = outputs[:2]
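
        # Illustrative follow-up (a sketch assuming eager execution; not part of the original
        # example): take the best start/end positions and join the tokens of that span.
        all_tokens = tokenizer.convert_ids_to_tokens(input_ids[0].numpy().tolist())
        start = int(tf.argmax(start_scores, axis=-1).numpy()[0])
        end = int(tf.argmax(end_scores, axis=-1).numpy()[0])
        answer = " ".join(all_tokens[start : end + 1])  # naive detokenization of the span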
- """
- outputs = self.bert(inputs, **kwargs)
- sequence_output = outputs[0]
- logits = self.qa_outputs(sequence_output)
- start_logits, end_logits = tf.split(logits, 2, axis=-1)
- start_logits = tf.squeeze(start_logits, axis=-1)
- end_logits = tf.squeeze(end_logits, axis=-1)
- outputs = (start_logits, end_logits,) + outputs[2:]
- return outputs # start_logits, end_logits, (hidden_states), (attentions)