基本信息
源码名称:深度学习框架
源码大小:12.70M
文件格式:.zip
开发语言:Python
更新时间:2021-11-26
   友情提示:(无需注册或充值,赞助后即可获取资源下载链接)

     嘿,亲!知识可是无价之宝呢,但咱这精心整理的资料也耗费了不少心血呀。小小地破费一下,绝对物超所值哦!如有下载和支付问题,请联系我们QQ(微信同号):813200300

本次赞助数额为: 2 元 
   源码介绍
通用深度学习框架

def main(argv):
  """Dispatch the command selected by ``FLAGS.cmd``.

  When ``FLAGS.config`` is non-empty, the config is loaded, logging is set up,
  the config is handed to ``utils.copy_config`` and the seed is set. Otherwise
  ``config`` stays ``None``, which only the ``build`` command can work with.

  Args:
    argv: positional arguments passed in by the application runner; unused.

  Raises:
    ValueError: if ``FLAGS.cmd`` is not a supported command, or if a
      solver-based command is requested while ``FLAGS.config`` is empty.
  """
  # pylint: disable=unused-argument

  if FLAGS.config != '':
    config = utils.load_config(FLAGS.config)
    utils.set_logging(FLAGS.log_debug, config)

    utils.copy_config(FLAGS.config, config)
    set_seed(config)
  else:
    config = None

  logging.info("Loading all modules ...")
  import_all_modules_for_register(config, only_nlp=FLAGS.only_nlp)

  logging.info("CMD: {}".format(FLAGS.cmd))

  # Commands that need a solver (and therefore a loaded config).
  solver_cmds = ('train', 'train_and_eval', 'eval', 'infer', 'export_model',
                 'gen_feat', 'gen_cmvn')

  if FLAGS.cmd in solver_cmds:
    if config is None:
      # Fail with a readable message instead of a TypeError on config[...].
      raise ValueError(
          "Command '{}' requires a config file, but FLAGS.config is empty."
          .format(FLAGS.cmd))

    solver_name = config['solver']['name']
    solver = registers.solver[solver_name](config)
    # config after process
    config = solver.config
    task_name = config['data']['task']['name']
    task_class = registers.task[task_name]

    if FLAGS.cmd == 'train':
      solver.train()
    elif FLAGS.cmd == 'train_and_eval':
      solver.train_and_eval()
    elif FLAGS.cmd == 'eval':
      solver.eval()
    elif FLAGS.cmd == 'infer':
      solver.infer(yield_single_examples=False)
    elif FLAGS.cmd == 'export_model':
      solver.export_model()
    elif FLAGS.cmd == 'gen_feat':
      assert config['data']['task'][
          'suffix'] == '.npy', 'wav does not need to extractor feature'
      # Collect the paths of every mode; plain assignment here would keep
      # only the paths of the last mode in the list.
      paths = []
      for mode in [utils.TRAIN, utils.EVAL, utils.INFER]:
        paths.extend(config['data'][mode]['paths'])
      task = task_class(config, utils.INFER)
      task.generate_feat(paths, dry_run=FLAGS.dry_run)
    elif FLAGS.cmd == 'gen_cmvn':
      logging.info(
          '''using infer pipeline to compute cmvn of train_paths, and stride must be 1'''
      )
      # Point the infer section at the training data so that the infer-mode
      # task built below works on the train paths/segments.
      paths = config['data'][utils.TRAIN]['paths']
      segments = config['data'][utils.TRAIN]['segments']
      config['data'][utils.INFER]['paths'] = paths
      config['data'][utils.INFER]['segments'] = segments
      task = task_class(config, utils.INFER)
      task.generate_cmvn(dry_run=FLAGS.dry_run)
  elif FLAGS.cmd == 'build':
    build_dataset(FLAGS.name, FLAGS.dir)
  else:
    raise ValueError("Not support command: {}.".format(FLAGS.cmd))


def entry():
  """Command-line entry point; ``only_nlp`` defaults to ``False``.

  Defines the shared flags plus the ``only_nlp`` flag, then hands ``main``
  to ``app.run``.

  Raises:
    SystemExit: re-raised unchanged when ``app.run`` terminates the process.
  """
  define_flags()
  # Use a real boolean as the default rather than the string 'False'.
  flags.DEFINE_bool('only_nlp', False, 'only use nlp modules')
  logging.info("Deep Language Technology Platform start...")
  try:
    app.run(main)
  except SystemExit as exit_info:
    # NOTE(review): assumes `app` is absl.app, whose run() finishes through
    # sys.exit(); a statement placed after it would never execute, so the
    # completion message is logged here when the exit status is clean.
    if exit_info.code in (None, 0):
      logging.info("OK. Done!")
    raise
  else:
    # Reached only if app.run() returns normally.
    logging.info("OK. Done!")


def nlp_entry():
  """Command-line entry point; ``only_nlp`` defaults to ``True``.

  Defines the shared flags plus the ``only_nlp`` flag, then hands ``main``
  to ``app.run``.

  Raises:
    SystemExit: re-raised unchanged when ``app.run`` terminates the process.
  """
  define_flags()
  # Use a real boolean as the default rather than the string 'True'.
  flags.DEFINE_bool('only_nlp', True, 'only use nlp modules')
  logging.info("Deep Language Technology Platform start...")
  try:
    app.run(main)
  except SystemExit as exit_info:
    # NOTE(review): assumes `app` is absl.app, whose run() finishes through
    # sys.exit(); a statement placed after it would never execute, so the
    # completion message is logged here when the exit status is clean.
    if exit_info.code in (None, 0):
      logging.info("OK. Done!")
    raise
  else:
    # Reached only if app.run() returns normally.
    logging.info("OK. Done!")


# Run the default entry point only when this file is executed as a script.
if "__main__" == __name__:
  entry()