FastSpeech2中文语音合成就步解析:TTS数据训练实战篇

  1. 参考github网址:

GitHub - roedoejet/FastSpeech2: An implementation of Microsoft’s “FastSpeech 2: Fast and High-Quality End-to-End Text to Speech”

  1. 数据训练所用python 命令:

python3 train.py -p config/AISHELL3/preprocess.yaml -m config/AISHELL3/model.yaml -t config/AISHELL3/train.yaml

  1. 数据训练代码解析

3.1 代码架构overview:

通过 if __name__ == "__main__"运行整个py文件:

调用 “train.txt"和dataset.py加载数据,

调用utils文件夹下的model.py加载模型,声码器,

调用model文件夹下的loss.py中的FastSpeech2Loss class 设置损失函数,

用前面加载的模型和损失函数开始训练模型,导出结果并记录日志。

3.2 按训练步骤分解代码:

Step 0 : 定义可控训练参数, 调动main函数

if __name__ == "__main__":

    #Define Args
    parser = argparse.ArgumentParser()
    parser.add_argument("--restore_step", type=int, default=0)
    parser.add_argument(
        "-p",
        "--preprocess_config",
        type=str,
        required=True,
        help="path to preprocess.yaml",
    )
    parser.add_argument(
        "-m", "--model_config", type=str, required=True, help="path to model.yaml"
    )
    parser.add_argument(
        "-t", "--train_config", type=str, required=True, help="path to train.yaml"
    )
    args = parser.parse_args() #args为可控训练参数

    # Read Config
    preprocess_config = yaml.load(
        open(args.preprocess_config, "r"), Loader=yaml.FullLoader
    )
    model_config = yaml.load(open(args.model_config, "r"), Loader=yaml.FullLoader)
    train_config = yaml.load(open(args.train_config, "r"), Loader=yaml.FullLoader)
    configs = (preprocess_config, model_config, train_config)

    #Run _main_ function
    main(args, configs)

Step 1 : 启动main函数,加载可控训练参数

def main(args, configs): 
    print("Prepare training ...")

    #加载可控训练参数
    preprocess_config, model_config, train_config = configs

Step 2 : 从train.txt加载数据,并经由dataset.py和torch里的Dataloader处理

def main(args, configs):

    # Get dataset
    dataset = Dataset(
        "train.txt", preprocess_config, train_config, sort=True, drop_last=True
    ) #从 train.txt 中获取dataset
    batch_size = train_config["optimizer"]["batch_size"]
    group_size = 4  # Set this larger than 1 to enable sorting in Dataset,初始值为4

    assert batch_size * group_size < len(dataset)
    loader = DataLoader(
        dataset,
        batch_size=batch_size * group_size,
        shuffle=True,
        collate_fn=dataset.collate_fn,
    )

Step 3 : 定义模型,声码器,损失函数

def main(args, configs):

    # Prepare model
    model, optimizer = get_model(args, configs, device, train=True) #设置优化器

    # 将模型并行训练并移入计算设备中
    model = nn.DataParallel(model) # Model Has Been Defined

    # 计算模型参数量
    num_param = get_param_num(model) # Number of TTS Parameters: num_param
    print("Number of FastSpeech2 Parameters:", num_param)

    # 设置损失函数
    Loss = FastSpeech2Loss(preprocess_config, model_config).to(device)

    # 加载声码器
    vocoder = get_vocoder(model_config, device)

Step 4 : 加载日志,在"./output/log/AISHELL3"目录建立train, val两个文件夹来记录日志

def main(args, configs):

    # Init logger
    for p in train_config["path"].values():
        os.makedirs(p, exist_ok=True)
    train_log_path = os.path.join(train_config["path"]["log_path"], "train")
    val_log_path = os.path.join(train_config["path"]["log_path"], "val")
    os.makedirs(train_log_path, exist_ok=True)
    os.makedirs(val_log_path, exist_ok=True)
    train_logger = SummaryWriter(train_log_path)
    val_logger = SummaryWriter(val_log_path)

Step 5 : 准备训练,加载可控训练参数

def main(args, configs):

    # Training
    step = args.restore_step + 1
    epoch = 1
    grad_acc_step = train_config["optimizer"]["grad_acc_step"]
    grad_clip_thresh = train_config["optimizer"]["grad_clip_thresh"]
    total_step = train_config["step"]["total_step"]
    log_step = train_config["step"]["log_step"]
    save_step = train_config["step"]["save_step"]
    synth_step = train_config["step"]["synth_step"]
    val_step = train_config["step"]["val_step"]

    outer_bar = tqdm(total=total_step, desc="Training", position=0)
    outer_bar.n = args.restore_step
    outer_bar.update()

Step 6 : 准备训练,加载进度条,调动utils文件夹下tools.py中的to_device function来提取数据

    while True:
        inner_bar = tqdm(total=len(loader), desc="Epoch {}".format(epoch), position=1)
        for batchs in loader:
            for batch in batchs:
                batch = to_device(batch, device)

Step 7 :开始训练,前向传播,计算损失,反向传播,梯度剪枝,更新模型权重参数

    #Load Data
            for batch in batchs:
                batch = to_device(batch, device)
                
                # Forward
                output = model(*(batch[2:]))

                # Cal Loss
                losses = Loss(batch, output)
                total_loss = losses[0]

                # Backward
                total_loss = total_loss / grad_acc_step
                total_loss.backward()
                if step % grad_acc_step == 0:
                    # Clipping gradients to avoid gradient explosion
                    nn.utils.clip_grad_norm_(model.parameters(), grad_clip_thresh)

                    # Update weights
                    optimizer.step_and_update_lr()
                    optimizer.zero_grad()

Step 8 : 当训练步数到达预先设定的log_step时,调动utils文件夹下tool.py里的log function,记录loss和step

                if step % log_step == 0:
                    losses = [l.item() for l in losses]
                    message1 = "Step {}/{}, ".format(step, total_step)
                    message2 = "Total Loss: {:.4f}, Mel Loss: {:.4f}, Mel PostNet Loss: {:.4f}, Pitch Loss: {:.4f}, Energy Loss: {:.4f}, Duration Loss: {:.4f}".format(
                        *losses
                    )

                    with open(os.path.join(train_log_path, "log.txt"), "a") as f:
                        f.write(message1 + message2 + "\n")

                    outer_bar.write(message1 + message2)

                    log(train_logger, step, losses=losses)

Step 9 : 当训练步数到达预先设定的synth_step时,调动utils文件夹下tool.py里的log function 和 synth_one_sample function(具体用来干什么没看懂)

                if step % synth_step == 0:
                    fig, wav_reconstruction, wav_prediction, tag = synth_one_sample(
                        batch,
                        output,
                        vocoder,
                        model_config,
                        preprocess_config,
                    )
                    log(
                        train_logger,
                        fig=fig,
                        tag="Training/step_{}_{}".format(step, tag),
                    )
                    sampling_rate = preprocess_config["preprocessing"]["audio"][
                        "sampling_rate"
                    ]
                    log(
                        train_logger,
                        audio=wav_reconstruction,
                        sampling_rate=sampling_rate,
                        tag="Training/step_{}_{}_reconstructed".format(step, tag),
                    )
                    log(
                        train_logger,
                        audio=wav_prediction,
                        sampling_rate=sampling_rate,
                        tag="Training/step_{}_{}_synthesized".format(step, tag),
                    )

Step 10 : 当训练步数到达预先设定的val_step时,调动evaluate.py里的evaluate function来进行evaluation,并记录在log/AISHELL3/val/log.txt

                if step % val_step == 0:
                    model.eval()
                    message = evaluate(model, step, configs, val_logger, vocoder)
                    with open(os.path.join(val_log_path, "log.txt"), "a") as f:
                        f.write(message + "\n")
                    outer_bar.write(message)

                    model.train()

Step 11 : 当训练步数到达预先设定的save_step时,保存训练模型

                if step % save_step == 0:
                    torch.save(
                        {
                            "model": model.module.state_dict(),
                            "optimizer": optimizer._optimizer.state_dict(),
                        },
                        os.path.join(
                            train_config["path"]["ckpt_path"],
                            "{}.pth.tar".format(step),
                        ),
                    )

Step 12 : 当训练步数到达预先设定的total_step时,退出训练

                if step == total_step:
                    quit()
                step += 1
                outer_bar.update(1)

            inner_bar.update(1)
        epoch += 1

  1. 数据训练代码的输出

在train_log_path和val_log_path输出日志

在ckpt_path输出训练过程中按照save_step存储的模型

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/779964.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

wordpress企业网站模板免费下载

大气上档次的wordpress企业模板&#xff0c;可以直接免费下载&#xff0c;连注册都不需要&#xff0c;网盘就可以直接下载&#xff0c;是不是嘎嘎给力呢 演示 https://www.jianzhanpress.com/?p5857 下载 链接: https://pan.baidu.com/s/1et7uMYd6--NJEWx-srMG1Q 提取码:…

【Python】已解决:nltk.download(‘stopwords‘) 报错问题

文章目录 一、分析问题背景二、可能出错的原因三、错误代码示例四、正确代码示例五、注意事项 已解决&#xff1a;nltk.download(‘stopwords’) 报错问题 一、分析问题背景 在使用Python的自然语言处理库NLTK&#xff08;Natural Language Toolkit&#xff09;时&#xff0c…

《向量数据库指南》——Milvus Cloud检索器增强的深度探讨:句子窗口检索与元数据过滤

检索器增强的深度探讨&#xff1a;句子窗口检索与元数据过滤 在信息爆炸的时代&#xff0c;高效的检索系统成为了连接用户与海量数据的关键桥梁。为了进一步提升检索的准确性和用户满意度&#xff0c;检索器增强技术应运而生&#xff0c;其中句子窗口检索与元数据过滤作为两大…

每日一题~oj(贪心)

对于位置 i来说&#xff0c;如果 不选她&#xff0c;那她的贡献是 vali-1 *2&#xff0c;如果选他 &#xff0c;那么她的贡献是 ai. 每一个数的贡献 是基于前一个数的贡献 来计算的。只要保证这个数的前一个数的贡献是最优的&#xff0c;那么以此类推下去&#xff0c;整体的val…

基于自编码器的时间序列异常检测方法(以传感器数据为例,MATLAB R2021b)

尽管近年来研究者对自编码器及其改进算法进行了深入研究&#xff0c;但现阶段仍存在以下问题亟须解决。 1) 无监督学习模式对特征提取能力的限制与有监督学习相比&#xff0c;无监督学习模式摆脱了对样本标签的依赖、避免了人工标注的困难&#xff0c;但也因此失去了样本标签的…

vue3+vite搭建第一个cesium项目详细步骤及环境配置(附源码)

文章目录 1.创建vuevite项目2.安装 Cesium2.1 安装cesium2.2 安装vite-plugin-cesium插件&#xff08;非必选&#xff09;2.3 新建组件页面map.vue2.4 加载地图 3.完成效果图 1.创建vuevite项目 打开cmd窗口执行以下命令&#xff1a;cesium-vue-app是你的项目名称 npm create…

Zkeys三方登录模块支持QQ、支付宝登录

1&#xff0c;覆盖到根目录&#xff0c;并导入update.sql数据库文件到Zkeys数据库里 2. 后台系统权限管理&#xff0c;配置管理员权限-系统类别-找到云外科技&#xff0c;全部打勾 3&#xff0c;后台系统设置找到云外快捷登录模块填写相应的插件授权配置和登录权限配置&#x…

【wordpress教程】wordpress博客网站添加非法关键词拦截

有的网站经常被恶意搜索&#xff0c;站长们不胜其烦。那我们如何屏蔽恶意搜索关键词呢&#xff1f;下面就随小编一起来解决这个问题吧。 后台设置预览图&#xff1a; 设置教程&#xff1a; 1、把以下代码添加至当前主题的 functions.php 文件中&#xff1a; add_action(admi…

Arcgis Api 三维聚合支持最新版API

Arcgis Api 三维聚合支持最新版API 最近有同学问我Arcgis api 三维聚合&#xff0c;官方还不支持三维聚合API&#xff0c;二维可以。所以依旧是通过GraphicLayers 类来实现&#xff0c;可支持最新Arcgis Api版本 效果图&#xff1a;

简单且循序渐进地查找软件中Bug的实用方法

“Bug”这个词常常让许多开发者感到头疼。即使是经验丰富、技术娴熟的开发人员在开发过程中也难以避免遭遇到 Bug。 软件中的故障会让程序员感到挫败。我相信在你的软件开发生涯中&#xff0c;也曾遇到过一些难以排查的问题。软件中的错误可能会导致项目无法按时交付。因此&…

初识STM32:芯片基本信息

STM32简介 STM32是ST公司基于ARM公司的Cortex-M内核开发的32位微控制器。 ARM公司是全球领先的半导体知识产权&#xff08;IP&#xff09;提供商&#xff0c;全世界超过95%的智能手机和平板电脑都采用ARM架构。 ST公司于1987年由意大利的SGS微电子与法国的Thomson半导体合并…

linux软链接和硬链接的区别

1 创建软链接和硬链接 如下图所示&#xff0c;一开始有两个文件soft和hard。使用 ln -s soft soft1创建软链接&#xff0c;soft1是soft的软链接&#xff1b;使用ln hard hard1创建硬链接&#xff0c;hard1是hard的硬链接。可以看到软链接的文件类型和其它3个文件的文件类型是不…

从“移花接木”到“提质增效”——详解嫁接打印技术

嫁接打印&#xff0c;是融合了3D打印与传统制造精髓的创新技术&#xff0c;其核心在于&#xff0c;通过巧妙地将传统模具加工与先进的3D打印技术相结合&#xff0c;实现了模具制造的“提质、增效、降本”。 嫁接打印的定义 简而言之&#xff0c;嫁接打印是一种增减材混合制造的…

uniapp报错--app.json: 在项目根目录未找到 app.json

【问题】 刚创建好的uni-app项目&#xff0c;运行微信小程序控制台报错如下&#xff1a; 【解决方案】 1. 程序根目录打开project.config.json文件 2. 配置miniprogramRoot&#xff0c;指定小程序代码的根目录 我的小程序代码编译后的工程文件目录为&#xff1a;dist/dev/mp…

阿里云Elasticsearch-趣味体验

阿里云Elasticsearch-趣味体验 什么是阿里云Elasticsearch阿里云Elasticsearch开通服务查看Elasticsearch实例配置Kibana公网IP登录Elasticsearch添加测试数据 Kibana数据分析查看数据字段筛选数据页面条件筛选KQL语法筛选保存搜索语句导出筛选结果指定列表展示字段写在最后 什…

multisim中关于74ls192n和DSWPK开关仿真图分析(减法计数器)

&#x1f3c6;本文收录于「Bug调优」专栏&#xff0c;主要记录项目实战过程中的Bug之前因后果及提供真实有效的解决方案&#xff0c;希望能够助你一臂之力&#xff0c;帮你早日登顶实现财富自由&#x1f680;&#xff1b;同时&#xff0c;欢迎大家关注&&收藏&&…

DAMA学习笔记(四)-数据建模与设计

1.引言 数据建模是发现、分析和确定数据需求的过程&#xff0c;用一种称为数据模型的精确形式表示和传递这些数据需求。建模过程中要求组织发现并记录数据组合的方式。数据常见的模式: 关系模式、多维模式、面向对象模式、 事实模式、时间序列模式和NoSQL模式。按照描述详细程度…

实现资产优化管理:智慧校园资产分类功能解析

在构建智慧校园的过程中&#xff0c;细致入微的资产管理是确保教育资源高效运作的关键一环&#xff0c;而资产分类功能则扮演着举足轻重的角色。系统通过精心设计的分类体系&#xff0c;将校园内的各类资产&#xff0c;从昂贵的教学设备到日常使用的办公物资&#xff0c;乃至无…

S32DS S32 Design Studio for S32 Platform 3.5 代码显示行号与空白符

介绍 NXP S32DS&#xff0c;全称 S32 Design Studio&#xff0c;s32 系列芯片默认使用 S32 Design Studio for S32 Platform 作为 IDE 集成开发环境&#xff0c;当前版本 S32 Design Studio for S32 Platform 3.5&#xff0c;IDE 可以简称 s32DS 使用 S32DS&#xff0c;可以认…

数据结构算法-排序(一)-冒泡排序

什么是冒泡排序 冒泡排序&#xff1a;在原数组中通过相邻两项元素的比较&#xff0c;交换而完成的排序算法。 算法核心 数组中相邻两项比较、交换。 算法复杂度 时间复杂度 实现一次排序找到最大值需要遍历 n-1次(n为数组长度) 需要这样的排序 n-1次。 需要 (n-1) * (n-1) —…