# run_eval.py
  1. # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. import argparse
  16. import shutil
  17. import datetime
  18. import json
  19. import time
  20. import warnings
  21. from logging import getLogger
  22. from pathlib import Path
  23. from typing import Dict, List
  24. from json import JSONDecodeError
  25. import torch
  26. from torch import nn
  27. from tqdm import tqdm
  28. from torch.utils.data import DataLoader
  29. import numpy as np
  30. import os
  31. import glob
  32. import dllogger
  33. from bart.configuration.configuration_bart import BartConfig
  34. from bart.tokenization.tokenization_bart import BartTokenizer
  35. from bart.modeling.modeling_bart import BartForConditionalGeneration, shift_tokens_right
  36. from utils.utils import (
  37. calculate_bleu,
  38. calculate_rouge,
  39. Seq2SeqDataset,
  40. parse_numeric_n_bool_cl_kwargs,
  41. use_task_specific_params,
  42. encode_line,
  43. load_json,
  44. lmap,
  45. chunks,
  46. write_txt_file,
  47. save_json,
  48. format_step)
  49. import utils.distributed_utils
# Module-level logger; handler/level configuration is left to the application.
logger = getLogger(__name__)
# Default inference device: prefer the GPU when CUDA is available.
DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
  52. def distill(layers, num_layers):
  53. sft_layers = nn.ModuleList()
  54. for i in range(num_layers):
  55. sft_layers.append(layers[i])
  56. # delete unnecessary layers
  57. delete_layers = [i for i in range(num_layers, len(layers))]
  58. for i in range(len(delete_layers)):
  59. del layers[delete_layers[i] - i]
  60. return sft_layers
  61. def distill_sft(model, num_layers, do_encoder=False, do_decoder=False):
  62. if do_encoder:
  63. layers = model.model.encoder.layers
  64. sft_layers = distill(layers, num_layers)
  65. model.model.encoder.layers = sft_layers
  66. if do_decoder:
  67. layers = model.model.decoder.layers
  68. sft_layers = distill(layers, num_layers)
  69. model.model.decoder.layers = sft_layers
  70. return model
  71. def generate_summaries_or_translations(
  72. data_dir: str,
  73. out_dir: str,
  74. model_path: str,
  75. config_path: str,
  76. batch_size: int = 8,
  77. device: str = DEFAULT_DEVICE,
  78. fp16=False,
  79. bf16=False,
  80. pre_ln=False,
  81. task="summarization",
  82. prefix=None,
  83. max_source_length=1024,
  84. max_target_length=142,
  85. eval_beams=5,
  86. eval_max_gen_length=142,
  87. n_obs=-1,
  88. type_path="test",
  89. num_return_sequences=1,
  90. distill=None,
  91. num_layers=None,
  92. do_encoder=False,
  93. do_decoder=False,
  94. **generate_kwargs,
  95. ) -> Dict:
  96. out_dir = Path(out_dir)
  97. save_path = out_dir.joinpath(f"rank_{utils.distributed_utils.get_rank()}_output.json")
  98. if num_return_sequences > eval_beams:
  99. eval_beams = num_return_sequences
  100. ### Define BART model
  101. # Config from "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-cnn/config.json
  102. # Vocab modified to 50265 to be consistent with facebook/bart-large default
  103. config = BartConfig(**json.load(open(config_path, "r")))
  104. if fp16:
  105. config.dtype = torch.float16
  106. elif bf16:
  107. config.dtype = torch.bfloat16
  108. else:
  109. config.dtype = None
  110. config.pre_ln = pre_ln
  111. model = BartForConditionalGeneration.from_pretrained(model_path, config=config).to(device)
  112. # if distilling, change model
  113. if distill == "sft":
  114. model = distill_sft(model, num_layers, do_encoder, do_decoder)
  115. if fp16:
  116. model = model.half()
  117. elif bf16:
  118. model = model.bfloat16()
  119. model.eval()
  120. tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn')
  121. logger.info(f"Inferred tokenizer type: {tokenizer.__class__}") # if this is wrong, check config.model_type.
  122. start_time = time.time()
  123. # update config with task specific params
  124. use_task_specific_params(model, task)
  125. if prefix is None:
  126. prefix = prefix or getattr(model.config, "prefix", "") or ""
  127. ds = Seq2SeqDataset(tokenizer, data_dir, max_source_length, max_target_length, type_path=type_path,
  128. n_obs=n_obs, prefix=prefix)
  129. # I set shuffle=True for a more accurate progress bar.
  130. # If all the longest samples are first, the prog bar estimate is too high at the beginning.
  131. is_distributed = True if utils.distributed_utils.get_world_size() > 1 else False
  132. sampler = ds.make_sortish_sampler(batch_size, distributed=is_distributed, add_extra_examples=False, shuffle=True)
  133. data_loader = DataLoader(ds, sampler=sampler, batch_size=batch_size, collate_fn=ds.collate_fn)
  134. results = []
  135. with torch.no_grad():
  136. for batch in tqdm(data_loader):
  137. torch.cuda.synchronize()
  138. t0 = time.time()
  139. summaries = model.generate(
  140. input_ids=batch["input_ids"].to(device),
  141. attention_mask=batch["attention_mask"].to(device),
  142. use_cache=True,
  143. num_return_sequences=num_return_sequences,
  144. num_beams=eval_beams,
  145. max_length=eval_max_gen_length,
  146. num_beam_groups=1, output_scores=False,
  147. return_dict_in_generate=False,
  148. encoder_no_repeat_ngram_size=0,
  149. diversity_penalty=0.0,
  150. **generate_kwargs,
  151. )
  152. preds = tokenizer.batch_decode(summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False)
  153. ids = batch["ids"]
  154. if num_return_sequences > 1:
  155. preds = chunks(preds, num_return_sequences) # batch size chunks, each of size num_return_seq
  156. torch.cuda.synchronize()
  157. eval_time = time.time() - t0
  158. for i, pred in enumerate(preds):
  159. store_time = eval_time if i == 0 else None #only store latency for element 0 of every batch
  160. results.append(dict(pred=pred, id=ids[i].item(), eval_time=store_time))
  161. save_json(results, save_path)
  162. runtime = int(time.time() - start_time) # seconds
  163. num_replicas = sampler.num_replicas if is_distributed else 1
  164. n_obs = len(results)
  165. return results, num_replicas, dict(n_obs=n_obs, eval_only_runtime=runtime, seconds_per_sample=round(runtime / n_obs, 4))
  166. def datetime_now():
  167. return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
  168. def run_generate(verbose=True):
  169. """
  170. Takes input text, generates output, and then using reference calculates the BLEU scores.
  171. The results are saved to a file and returned to the caller, and printed out unless ``verbose=False`` is passed.
  172. Args:
  173. verbose (:obj:`bool`, `optional`, defaults to :obj:`True`): print results to stdout
  174. Returns:
  175. a tuple: ``(scores, params}``
  176. - ``scores``: a dict of scores data ``{'bleu': 39.6501, 'n_obs': 2000, 'runtime': 186, 'seconds_per_sample': 0.093}``
  177. - ``params``: a dict of custom params, e.g. ``{'num_beams': 5, 'length_penalty': 0.8}``
  178. """
  179. parser = argparse.ArgumentParser()
  180. parser.add_argument("model_path", type=str, help="like facebook/bart-large-cnn or path to ckpt")
  181. parser.add_argument("config_path", type=str, help="path to config")
  182. parser.add_argument("data_dir", type=str, help="like cnn_dm/test.source")
  183. parser.add_argument("save_path", type=str, help="where to save summaries")
  184. parser.add_argument("--type_path", type=str, required=False, default="test", help="like cnn_dm/test.target")
  185. parser.add_argument("--device", type=str, required=False, default=DEFAULT_DEVICE, help="cuda, cuda:1, cpu etc.")
  186. parser.add_argument(
  187. "--prefix", type=str, required=False, default=None, help="will be added to the begininng of src examples"
  188. )
  189. parser.add_argument("--task", type=str, default="summarization", help="used for task_specific_params + metrics")
  190. parser.add_argument("--bs", type=int, default=8, required=False, help="batch size")
  191. parser.add_argument(
  192. "--n_obs", type=int, default=None, required=False, help="How many observations. Defaults to all."
  193. )
  194. parser.add_argument(
  195. "--num_return_sequences", type=int, default=1, required=False, help="How many sequences to return"
  196. )
  197. parser.add_argument("--fp16", action="store_true")
  198. parser.add_argument("--bf16", action="store_true")
  199. parser.add_argument("--dump-args", action="store_true", help="print the custom hparams with the results")
  200. parser.add_argument(
  201. "--info",
  202. nargs="?",
  203. type=str,
  204. const=datetime_now(),
  205. help="use in conjunction w/ --dump-args to print with the results whatever other info you'd like, e.g. lang=en-ru. If no value is passed, the current datetime string will be used.",
  206. )
  207. parser.add_argument("--eval_max_gen_length", type=int, default=None, help="never generate more than n tokens")
  208. parser.add_argument("--eval_beams", type=int, default=None, required=False, help="# beams to use. 0 corresponds to not using beam search.")
  209. parser.add_argument(
  210. "--max_source_length",
  211. default=1024,
  212. type=int,
  213. help="The maximum total input sequence length after tokenization. Sequences longer "
  214. "than this will be truncated, sequences shorter will be padded.",
  215. )
  216. parser.add_argument(
  217. "--max_target_length",
  218. default=142,
  219. type=int,
  220. help="The maximum total input sequence length after tokenization. Sequences longer "
  221. "than this will be truncated, sequences shorter will be padded.",
  222. )
  223. parser.add_argument(
  224. "--sync_timeout",
  225. type=int,
  226. default=600,
  227. required=False,
  228. help="How long should master process wait for other processes to finish.",
  229. )
  230. parser.add_argument("--debug", action="store_true")
  231. parser.add_argument('--json-summary', type=str, default="results/dllogger.json",
  232. help='If provided, the json summary will be written to'
  233. 'the specified file.')
  234. parser.add_argument('--distill', type=str, default=None, help="string indicating how model is distilled, only sft supported", choices=["sft",None])
  235. parser.add_argument('--layers', type=str, default=None, help="string indicating which teacher layers remain, split by '-' (ex. 0-6-11)")
  236. parser.add_argument('--do_encoder', action="store_true", default=False, help="if true encoder distilled")
  237. parser.add_argument('--do_decoder', action="store_true", default=False, help="if true decoder distilled")
  238. parser.add_argument("--pre_ln",
  239. default=False,
  240. action='store_true',
  241. help="Whether to use Pre-LN architecture."
  242. )
  243. dist = parser.add_argument_group('distributed setup')
  244. dist.add_argument('--local_rank', type=int,
  245. default=os.getenv('LOCAL_RANK', 0),
  246. help='Used for multi-process training.')
  247. start_time = time.time()
  248. # Unspecified args like --num_beams=2 --decoder_start_token_id=4 are passed to model.generate
  249. args, rest = parser.parse_known_args()
  250. parsed_args = parse_numeric_n_bool_cl_kwargs(rest)
  251. if args.local_rank <= 0:
  252. print(args)
  253. print(rest)
  254. # Initialize device and distributed backend
  255. utils.distributed_utils.init_distributed(args.device == "cuda")
  256. if utils.distributed_utils.get_world_size() > 1:
  257. utils.distributed_utils.set_affinity(args.local_rank)
  258. torch.cuda.set_device(args.local_rank)
  259. if Path(args.json_summary).exists():
  260. warnings.warn(f"json_summary {args.json_summary} will be overwritten unless you type ctrl-c.")
  261. if utils.distributed_utils.get_rank() == 0:
  262. dllogger.init(backends=[dllogger.JSONStreamBackend(verbosity=dllogger.Verbosity.VERBOSE,
  263. filename=args.json_summary),
  264. dllogger.StdOutBackend(verbosity=dllogger.Verbosity.VERBOSE, step_format=format_step)])
  265. else:
  266. dllogger.init(backends=[])
  267. if parsed_args and verbose:
  268. print(f"parsed the following generate kwargs: {parsed_args}")
  269. Path(args.save_path).parent.mkdir(exist_ok=True)
  270. json_save_path = Path(args.save_path + "/tmp")
  271. Path(json_save_path).mkdir(exist_ok=True) # this handles locking.
  272. if args.layers:
  273. num_layers = len(args.layers.split('-'))
  274. else:
  275. num_layers = None
  276. results, num_replicas, runtime_metrics = generate_summaries_or_translations(
  277. args.data_dir,
  278. json_save_path,
  279. args.model_path,
  280. args.config_path,
  281. batch_size=args.bs,
  282. device=args.device,
  283. fp16=args.fp16,
  284. bf16=args.bf16,
  285. pre_ln=args.pre_ln,
  286. task=args.task,
  287. prefix=args.prefix,
  288. eval_beams=args.eval_beams,
  289. max_source_length=args.max_source_length,
  290. max_target_length=args.max_target_length,
  291. eval_max_gen_length=args.eval_max_gen_length,
  292. n_obs=args.n_obs,
  293. type_path=args.type_path,
  294. num_return_sequences=args.num_return_sequences,
  295. distill=args.distill,
  296. num_layers=num_layers,
  297. do_encoder=args.do_encoder,
  298. do_decoder=args.do_decoder,
  299. **parsed_args,
  300. )
  301. if args.local_rank <= 0:
  302. save_path = Path(args.save_path)
  303. save_path.mkdir(exist_ok=True)
  304. partial_results = gather_results_from_each_node(num_replicas, json_save_path, args.sync_timeout)
  305. preds, time_list = combine_partial_results(partial_results)
  306. if args.num_return_sequences > 1:
  307. save_path = save_path.joinpath("pseudolabel_results.json")
  308. print(f"Saving aggregated results at {save_path}, intermediate in {json_save_path}/")
  309. save_json(preds, save_path)
  310. return
  311. tgt_file = Path(args.data_dir).joinpath(args.type_path + ".target")
  312. labels = [x.rstrip() for x in open(tgt_file).readlines()][: len(preds)]
  313. # Calculate metrics, save metrics, and save _generations.txt
  314. calc_bleu = "translation" in args.task
  315. score_fn = calculate_bleu if calc_bleu else calculate_rouge
  316. metric_name = "bleu" if calc_bleu else "rouge"
  317. metrics: Dict = score_fn(preds, labels)
  318. metrics["n_obs"] = len(preds)
  319. runtime = time.time() - start_time
  320. metrics["seconds_per_sample"] = round(runtime / metrics["n_obs"], 4)
  321. metrics["n_gpus"] = num_replicas
  322. metrics.update(runtime_metrics)
  323. time_list.sort()
  324. metrics["inference_latency_mean"] = np.mean(time_list)
  325. metrics["inference_latency_conf_50"] = max(time_list[:int(len(time_list) * 0.50)])
  326. metrics["inference_latency_conf_90"] = max(time_list[:int(len(time_list) * 0.90)])
  327. metrics["inference_latency_conf_95"] = max(time_list[:int(len(time_list) * 0.95)])
  328. metrics["inference_latency_conf_99"] = max(time_list[:int(len(time_list) * 0.99)])
  329. metrics["inference_latency_conf_100"] = max(time_list[:int(len(time_list) * 1)])
  330. metrics["inference_throughput_mean"] = len(preds) * 1.0 / sum(time_list)
  331. metrics_save_path = save_path.joinpath(f"{args.type_path}_{metric_name}.json")
  332. save_json(metrics, metrics_save_path, indent=None)
  333. dllogger.log(step=tuple(), data=metrics)
  334. print(metrics)
  335. write_txt_file(preds, save_path.joinpath(f"{args.type_path}_generations.txt"))
  336. if args.debug:
  337. write_txt_file(labels, save_path.joinpath(f"{args.type_path}.target"))
  338. else:
  339. shutil.rmtree(json_save_path)
  340. dllogger.flush()
  341. def combine_partial_results(partial_results) -> List:
  342. """Concatenate partial results into one file, then sort it by id."""
  343. records = []
  344. for partial_result in partial_results:
  345. records.extend(partial_result)
  346. records = list(sorted(records, key=lambda x: x["id"]))
  347. preds = [x["pred"] for x in records]
  348. eval_time = [x["eval_time"] for x in records if x["eval_time"] is not None]
  349. return preds, eval_time
  350. def gather_results_from_each_node(num_replicas, save_path, timeout) -> List[Dict[str, List]]:
  351. # WAIT FOR lots of .json files
  352. start_wait = time.time()
  353. logger.info("waiting for all nodes to finish")
  354. json_data = None
  355. while (time.time() - start_wait) < timeout:
  356. json_files = list(save_path.glob("rank_*.json"))
  357. if len(json_files) < num_replicas:
  358. continue
  359. try:
  360. # make sure all json files are fully saved
  361. json_data = lmap(load_json, json_files)
  362. return json_data
  363. except JSONDecodeError:
  364. continue
  365. else:
  366. raise TimeoutError("Rank 0 gave up on waiting for other processes")
  367. # Unreachable
# Script entry point: parse CLI args, run generation, aggregate, and score.
if __name__ == "__main__":
    # Usage for MT:
    # python run_eval.py MODEL_NAME $DATA_DIR/test.source $save_path/test_translations.txt --reference_path $DATA_DIR/test.target --task translation $@
    run_generate(verbose=True)