# inference.sh
  1. #!/usr/bin/env bash
  2. # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. set -e
  16. : ${DATASET_DIR:="/datasets/LibriSpeech"}
  17. : ${VALID_SUBSET:="test-other"}
  18. : ${OUTPUT_DIR:="results/inference"}
  19. : ${NUM_GPUS:=1}
  20. : ${BATCH_SIZE:=8}
  21. : ${AMP:=false}
  22. : ${BF16:=false}
  23. : ${FP16:=false}
  24. : ${EMA:=0.0}
  25. : ${SEED:=1}
  26. : ${FINETUNED_MODEL:=results/finetune_base_960h/wav2vec2_update320000.pt}
  27. : ${MASK_PROB:=0.5}
  28. : ${MASK_CHANNEL_PROB:=0.25}
  29. : ${DISTRIBUTED:="-m torch.distributed.launch --nproc_per_node=$NUM_GPUS"}
  30. # inference
  31. : ${MAX_DURATION:=""}
  32. : ${NUM_STEPS:=0}
  33. : ${NUM_WARMUP_STEPS:=0}
  34. : ${CPU:=false}
  35. : ${LOGITS_FILE:=}
  36. : ${PREDICTION_FILE:="${OUTPUT_DIR}/${DATASET}.predictions"}
  37. : ${TORCHSCRIPT:=false}
  38. : ${TORCHSCRIPT_SAVE:=false}
  39. : ${LOG_FILE:=$OUTPUT_DIR/nvlog.json}
  40. mkdir -p "$OUTPUT_DIR"
  41. ARGS+=" --w2v_path $FINETUNED_MODEL"
  42. ARGS+=" --data $DATASET_DIR"
  43. ARGS+=" --valid_subset $VALID_SUBSET"
  44. ARGS+=" --output_dir $OUTPUT_DIR"
  45. ARGS+=" --ema $EMA"
  46. ARGS+=" --seed $SEED"
  47. ARGS+=" --skip_invalid_size_inputs_valid_test"
  48. ARGS+=" --apply_mask"
  49. ARGS+=" --mask_prob $MASK_PROB"
  50. ARGS+=" --mask_channel_prob $MASK_CHANNEL_PROB"
  51. ARGS+=" --mask_channel_length 64"
  52. ARGS+=" --encoder_layerdrop 0.1" # NOTE This is called `layerdrop` in fairseq finetuning yamls
  53. ARGS+=" --activation_dropout 0.1"
  54. ARGS+=" --feature_grad_mult 0.0"
  55. ARGS+=" --batch_size=$BATCH_SIZE"
  56. ARGS+=" --steps $NUM_STEPS"
  57. ARGS+=" --warmup_steps $NUM_WARMUP_STEPS"
  58. [ "$AMP" = true ] && ARGS+=" --amp --fp16"
  59. [ "$BF16" = true ] && ARGS+=" --bf16"
  60. [ "$TORCHSCRIPT" = true ] && ARGS+=" --torchscript"
  61. [ "$TORCHSCRIPT_SAVE" = true ] && ARGS+=" --torchscript_export"
  62. [ -n "$LOG_FILE" ] && ARGS+=" --log_file $LOG_FILE"
  63. [ "$CPU" == "true" ] && ARGS+=" --cpu"
  64. [ -n "$MAX_DURATION" ] && ARGS+=" --max_duration ${MAX_DURATION}"
  65. set -x
  66. if [ $NUM_GPUS -gt 1 ]; then
  67. python3 -m torch.distributed.launch --nproc_per_node=$NUM_GPUS inference.py $ARGS $@
  68. else
  69. python3 inference.py $ARGS $@
  70. fi