|
|
@@ -32,7 +32,7 @@ import argparse
|
|
|
|
|
|
def parse_args(parser):
|
|
|
"""
|
|
|
- Parse commandline arguments.
|
|
|
+ Parse commandline arguments.
|
|
|
"""
|
|
|
parser.add_argument("--trtis_model_name",
|
|
|
type=str,
|
|
|
@@ -42,8 +42,8 @@ def parse_args(parser):
|
|
|
type=int,
|
|
|
default=1,
|
|
|
help="exports to appropriate directory for TRTIS")
|
|
|
- parser.add_argument('--amp-run', action='store_true',
|
|
|
- help='inference with AMP')
|
|
|
+ parser.add_argument('--fp16', action='store_true',
|
|
|
+ help='inference with mixed precision')
|
|
|
return parser
|
|
|
|
|
|
|
|
|
@@ -52,13 +52,13 @@ def main():
|
|
|
description='PyTorch WaveGlow TRTIS config exporter')
|
|
|
parser = parse_args(parser)
|
|
|
args = parser.parse_args()
|
|
|
-
|
|
|
+
|
|
|
# prepare repository
|
|
|
model_folder = os.path.join('./trtis_repo', args.trtis_model_name)
|
|
|
version_folder = os.path.join(model_folder, str(args.trtis_model_version))
|
|
|
if not os.path.exists(version_folder):
|
|
|
os.makedirs(version_folder)
|
|
|
-
|
|
|
+
|
|
|
# build the config for TRTIS
|
|
|
config_filename = os.path.join(model_folder, "config.pbtxt")
|
|
|
config_template = r"""
|
|
|
@@ -84,12 +84,12 @@ output {{
|
|
|
dims: [-1]
|
|
|
}}
|
|
|
"""
|
|
|
-
|
|
|
+
|
|
|
config_values = {
|
|
|
"model_name": args.trtis_model_name,
|
|
|
- "fp_type": "TYPE_FP16" if args.amp_run else "TYPE_FP32"
|
|
|
+ "fp_type": "TYPE_FP16" if args.fp16 else "TYPE_FP32"
|
|
|
}
|
|
|
-
|
|
|
+
|
|
|
with open(model_folder + "/config.pbtxt", "w") as file:
|
|
|
final_config_str = config_template.format_map(config_values)
|
|
|
file.write(final_config_str)
|
|
|
@@ -97,4 +97,3 @@ output {{
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
main()
|
|
|
-
|