tts: add speaker file support (#12048)

* tts: add speaker file support

Signed-off-by: dm4 <sunrisedm4@gmail.com>

* tts: handle outetts-0.3

* tts : add new line in error message

---------

Signed-off-by: dm4 <sunrisedm4@gmail.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
dm4
2025-03-03 21:09:29 +08:00
committed by GitHub
parent d5c63cd7f9
commit c43af9276b
3 changed files with 258 additions and 140 deletions

View File

@ -2452,6 +2452,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.vocoder.use_guide_tokens = true; params.vocoder.use_guide_tokens = true;
} }
).set_examples({LLAMA_EXAMPLE_TTS, LLAMA_EXAMPLE_SERVER})); ).set_examples({LLAMA_EXAMPLE_TTS, LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"--tts-speaker-file"}, "FNAME",
"speaker file path for audio generation",
[](common_params & params, const std::string & value) {
params.vocoder.speaker_file = value;
}
).set_examples({LLAMA_EXAMPLE_TTS}));
// model-specific // model-specific
add_opt(common_arg( add_opt(common_arg(

View File

@ -200,6 +200,8 @@ struct common_params_vocoder {
std::string model = ""; // model path // NOLINT std::string model = ""; // model path // NOLINT
std::string model_url = ""; // model url to download // NOLINT std::string model_url = ""; // model url to download // NOLINT
std::string speaker_file = ""; // speaker file path // NOLINT
bool use_guide_tokens = false; // enable guide tokens to improve TTS accuracy // NOLINT bool use_guide_tokens = false; // enable guide tokens to improve TTS accuracy // NOLINT
}; };

View File

@ -3,6 +3,7 @@
#include "sampling.h" #include "sampling.h"
#include "log.h" #include "log.h"
#include "llama.h" #include "llama.h"
#include "json.hpp"
#define _USE_MATH_DEFINES // For M_PI on MSVC #define _USE_MATH_DEFINES // For M_PI on MSVC
@ -16,6 +17,13 @@
#include <thread> #include <thread>
#include <vector> #include <vector>
using json = nlohmann::ordered_json;
enum outetts_version {
OUTETTS_V0_2,
OUTETTS_V0_3,
};
// //
// Terminal utils // Terminal utils
// //
@ -371,7 +379,7 @@ static std::string replace_numbers_with_words(const std::string & input_text) {
} }
// Based on: https://github.com/edwko/OuteTTS/blob/a613e79c489d8256dd657ea9168d78de75895d82/outetts/version/v1/prompt_processor.py#L39 // Based on: https://github.com/edwko/OuteTTS/blob/a613e79c489d8256dd657ea9168d78de75895d82/outetts/version/v1/prompt_processor.py#L39
static std::string process_text(const std::string & text) { static std::string process_text(const std::string & text, const outetts_version tts_version = OUTETTS_V0_2) {
// For now I skipped text romanization as I am unsure how to handle // For now I skipped text romanization as I am unsure how to handle
// uroman and MeCab implementations in C++ // uroman and MeCab implementations in C++
@ -401,7 +409,8 @@ static std::string process_text(const std::string & text) {
if (c == ' ') { if (c == ' ') {
prompt_clean += "<|text_sep|>"; prompt_clean += "<|text_sep|>";
*/ */
processed_text = std::regex_replace(processed_text, std::regex(R"(\s)"), "<|text_sep|>"); std::string separator = (tts_version == OUTETTS_V0_3) ? "<|space|>" : "<|text_sep|>";
processed_text = std::regex_replace(processed_text, std::regex(R"(\s)"), separator);
return processed_text; return processed_text;
} }
@ -425,8 +434,8 @@ static void prompt_init(llama_tokens & prompt, const llama_vocab * vocab) {
prompt_add(prompt, vocab, "<|im_start|>\n", true, true); prompt_add(prompt, vocab, "<|im_start|>\n", true, true);
} }
static std::vector<llama_token> prepare_guide_tokens(const llama_vocab * vocab, const std::string & str) { static std::vector<llama_token> prepare_guide_tokens(const llama_vocab * vocab, const std::string & str, const outetts_version tts_version = OUTETTS_V0_2) {
const std::string& delimiter = "<|text_sep|>"; const std::string& delimiter = (tts_version == OUTETTS_V0_3 ? "<|space|>" : "<|text_sep|>");
std::vector<llama_token> result; std::vector<llama_token> result;
size_t start = 0; size_t start = 0;
@ -452,6 +461,78 @@ static std::vector<llama_token> prepare_guide_tokens(const llama_vocab * vocab,
return result; return result;
} }
// Load a speaker profile (OuteTTS JSON format) from disk.
// Returns a null json (for which .empty() is true) on open or parse failure,
// so callers can detect errors with speaker.empty().
static json speaker_from_file(const std::string & speaker_file) {
    std::ifstream file(speaker_file);
    if (!file) {
        LOG_ERR("%s: Failed to open file '%s' for reading\n", __func__, speaker_file.c_str());
        return json();
    }

    // parse without exceptions: a malformed file must not crash the program,
    // it should take the same error path as a missing file
    json speaker = json::parse(file, /*cb=*/nullptr, /*allow_exceptions=*/false);
    if (speaker.is_discarded()) {
        LOG_ERR("%s: Failed to parse JSON in file '%s'\n", __func__, speaker_file.c_str());
        return json();
    }

    return speaker;
}
// Determine the OuteTTS version to use: an explicit "version" field in the
// speaker profile wins; otherwise fall back to the model's chat template
// name; otherwise default to 0.2.
static outetts_version get_tts_version(llama_model *model, json speaker = json::object()) {
    // the speaker profile is the most authoritative source
    if (speaker.contains("version")) {
        const std::string version = speaker["version"].get<std::string>();
        if (version == "0.2") {
            return OUTETTS_V0_2;
        }
        if (version == "0.3") {
            return OUTETTS_V0_3;
        }
        LOG_ERR("%s: Unsupported speaker version '%s'\n", __func__, version.c_str());
    }

    // no usable speaker version — check whether the model identifies itself
    const char * chat_template = llama_model_chat_template(model, nullptr);
    if (chat_template != nullptr && std::string(chat_template) == "outetts-0.3") {
        return OUTETTS_V0_3;
    }

    // default version
    return OUTETTS_V0_2;
}
static std::string audio_text_from_speaker(json speaker, const outetts_version tts_version = OUTETTS_V0_2) {
std::string audio_text = "<|text_start|>";
if (tts_version == OUTETTS_V0_2 || tts_version == OUTETTS_V0_3) {
std::string separator = (tts_version == OUTETTS_V0_3) ? "<|space|>" : "<|text_sep|>";
for (const auto &word : speaker["words"]) {
audio_text += word["word"].get<std::string>() + separator;
}
}
return audio_text;
}
static std::string audio_data_from_speaker(json speaker, const outetts_version tts_version = OUTETTS_V0_2) {
std::string audio_data = "<|audio_start|>\n";
if (tts_version == OUTETTS_V0_2 || tts_version == OUTETTS_V0_3) {
std::string code_start = (tts_version == OUTETTS_V0_3) ? "" : "<|code_start|>";
std::string code_end = (tts_version == OUTETTS_V0_3) ? "<|space|>" : "<|code_end|>";
for (const auto &word : speaker["words"]) {
std::string word_text = word["word"].get<std::string>();
double duration = word["duration"].get<double>();
std::vector<int> codes = word["codes"].get<std::vector<int>>();
// Create the audio output entry
std::ostringstream word_entry;
word_entry << word_text << "<|t_" << std::fixed << std::setprecision(2)
<< duration << "|>" + code_start;
for (const auto &Code : codes) {
word_entry << "<|" << Code << "|>";
}
word_entry << code_end << "\n";
audio_data += word_entry.str();
}
}
return audio_data;
}
int main(int argc, char ** argv) { int main(int argc, char ** argv) {
common_params params; common_params params;
@ -523,34 +604,9 @@ int main(int argc, char ** argv) {
std::vector<llama_token> codes; std::vector<llama_token> codes;
std::vector<llama_token> guide_tokens; std::vector<llama_token> guide_tokens;
// process prompt and generate voice codes // the default speaker profile is from: https://github.com/edwko/OuteTTS/blob/main/outetts/version/v1/default_speakers/en_male_1.json
{ std::string audio_text = "<|text_start|>the<|text_sep|>overall<|text_sep|>package<|text_sep|>from<|text_sep|>just<|text_sep|>two<|text_sep|>people<|text_sep|>is<|text_sep|>pretty<|text_sep|>remarkable<|text_sep|>sure<|text_sep|>i<|text_sep|>have<|text_sep|>some<|text_sep|>critiques<|text_sep|>about<|text_sep|>some<|text_sep|>of<|text_sep|>the<|text_sep|>gameplay<|text_sep|>aspects<|text_sep|>but<|text_sep|>its<|text_sep|>still<|text_sep|>really<|text_sep|>enjoyable<|text_sep|>and<|text_sep|>it<|text_sep|>looks<|text_sep|>lovely<|text_sep|>";
LOG_INF("%s: constructing prompt ..\n", __func__); std::string audio_data = R"(<|audio_start|>
std::vector<llama_token> prompt_inp;
prompt_init(prompt_inp, vocab);
prompt_add(prompt_inp, vocab, "<|text_start|>the<|text_sep|>overall<|text_sep|>package<|text_sep|>from<|text_sep|>just<|text_sep|>two<|text_sep|>people<|text_sep|>is<|text_sep|>pretty<|text_sep|>remarkable<|text_sep|>sure<|text_sep|>i<|text_sep|>have<|text_sep|>some<|text_sep|>critiques<|text_sep|>about<|text_sep|>some<|text_sep|>of<|text_sep|>the<|text_sep|>gameplay<|text_sep|>aspects<|text_sep|>but<|text_sep|>its<|text_sep|>still<|text_sep|>really<|text_sep|>enjoyable<|text_sep|>and<|text_sep|>it<|text_sep|>looks<|text_sep|>lovely<|text_sep|>", false, true);
// convert the input text into the necessary format expected by OuteTTS
{
std::string prompt_clean = process_text(params.prompt);
if (params.vocoder.use_guide_tokens) {
guide_tokens = prepare_guide_tokens(vocab, prompt_clean);
}
LOG_INF("%s: prompt: '%s'\n", __func__, prompt_clean.c_str());
prompt_add(prompt_inp, vocab, prompt_clean, false, true);
}
prompt_add(prompt_inp, vocab, "<|text_end|>\n", false, true);
// disabled to save time on tokenizing each time
// TODO: load voices from the json files
#if 0
const std::string voice_data = R"(<|audio_start|>
the<|t_0.08|><|code_start|><|257|><|740|><|636|><|913|><|788|><|1703|><|code_end|> the<|t_0.08|><|code_start|><|257|><|740|><|636|><|913|><|788|><|1703|><|code_end|>
overall<|t_0.36|><|code_start|><|127|><|201|><|191|><|774|><|700|><|532|><|1056|><|557|><|798|><|298|><|1741|><|747|><|1662|><|1617|><|1702|><|1527|><|368|><|1588|><|1049|><|1008|><|1625|><|747|><|1576|><|728|><|1019|><|1696|><|1765|><|code_end|> overall<|t_0.36|><|code_start|><|127|><|201|><|191|><|774|><|700|><|532|><|1056|><|557|><|798|><|298|><|1741|><|747|><|1662|><|1617|><|1702|><|1527|><|368|><|1588|><|1049|><|1008|><|1625|><|747|><|1576|><|728|><|1019|><|1696|><|1765|><|code_end|>
package<|t_0.56|><|code_start|><|935|><|584|><|1319|><|627|><|1016|><|1491|><|1344|><|1117|><|1526|><|1040|><|239|><|1435|><|951|><|498|><|723|><|1180|><|535|><|789|><|1649|><|1637|><|78|><|465|><|1668|><|901|><|595|><|1675|><|117|><|1009|><|1667|><|320|><|840|><|79|><|507|><|1762|><|1508|><|1228|><|1768|><|802|><|1450|><|1457|><|232|><|639|><|code_end|> package<|t_0.56|><|code_start|><|935|><|584|><|1319|><|627|><|1016|><|1491|><|1344|><|1117|><|1526|><|1040|><|239|><|1435|><|951|><|498|><|723|><|1180|><|535|><|789|><|1649|><|1637|><|78|><|465|><|1668|><|901|><|595|><|1675|><|117|><|1009|><|1667|><|320|><|840|><|79|><|507|><|1762|><|1508|><|1228|><|1768|><|802|><|1450|><|1457|><|232|><|639|><|code_end|>
@ -582,12 +638,64 @@ it<|t_0.09|><|code_start|><|848|><|1366|><|395|><|1601|><|1513|><|593|><|1302|><
looks<|t_0.27|><|code_start|><|1281|><|1266|><|1755|><|572|><|248|><|1751|><|1257|><|695|><|1380|><|457|><|659|><|585|><|1315|><|1105|><|1776|><|736|><|24|><|736|><|654|><|1027|><|code_end|> looks<|t_0.27|><|code_start|><|1281|><|1266|><|1755|><|572|><|248|><|1751|><|1257|><|695|><|1380|><|457|><|659|><|585|><|1315|><|1105|><|1776|><|736|><|24|><|736|><|654|><|1027|><|code_end|>
lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|1481|><|1721|><|1123|><|438|><|1246|><|1251|><|795|><|659|><|1381|><|1658|><|217|><|1772|><|562|><|952|><|107|><|1129|><|1112|><|467|><|550|><|1079|><|840|><|1615|><|1469|><|1380|><|168|><|917|><|836|><|1827|><|437|><|583|><|67|><|595|><|1087|><|1646|><|1493|><|1677|><|code_end|>)"; lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|1481|><|1721|><|1123|><|438|><|1246|><|1251|><|795|><|659|><|1381|><|1658|><|217|><|1772|><|562|><|952|><|107|><|1129|><|1112|><|467|><|550|><|1079|><|840|><|1615|><|1469|><|1380|><|168|><|917|><|836|><|1827|><|437|><|583|><|67|><|595|><|1087|><|1646|><|1493|><|1677|><|code_end|>)";
// audio data for 0.3 version
outetts_version tts_version = get_tts_version(model_ttc);
if (tts_version == OUTETTS_V0_3) {
audio_text = std::regex_replace(audio_text, std::regex(R"(<\|text_sep\|>)"), "<|space|>");
audio_data = std::regex_replace(audio_data, std::regex(R"(<\|code_start\|>)"), "");
audio_data = std::regex_replace(audio_data, std::regex(R"(<\|code_end\|>)"), "<|space|>");
}
// load speaker if given
if (!params.vocoder.speaker_file.empty()) {
LOG_INF("%s: loading speaker ..\n", __func__);
json speaker = speaker_from_file(params.vocoder.speaker_file);
if (speaker.empty()) {
LOG_ERR("%s: Failed to load speaker file '%s'\n", __func__, params.vocoder.speaker_file.c_str());
return 1;
}
audio_text = audio_text_from_speaker(speaker, tts_version);
audio_data = audio_data_from_speaker(speaker, tts_version);
}
// process prompt and generate voice codes
{
LOG_INF("%s: constructing prompt ..\n", __func__);
std::vector<llama_token> prompt_inp;
prompt_init(prompt_inp, vocab);
prompt_add(prompt_inp, vocab, audio_text, false, true);
// convert the input text into the necessary format expected by OuteTTS
{
std::string prompt_clean = process_text(params.prompt, tts_version);
if (params.vocoder.use_guide_tokens) {
guide_tokens = prepare_guide_tokens(vocab, prompt_clean, tts_version);
}
LOG_INF("%s: prompt: '%s'\n", __func__, prompt_clean.c_str());
prompt_add(prompt_inp, vocab, prompt_clean, false, true);
}
prompt_add(prompt_inp, vocab, "<|text_end|>\n", false, true);
if (!params.vocoder.speaker_file.empty()) {
prompt_add(prompt_inp, vocab, audio_data, false, true);
} else {
// disabled to save time on tokenizing each time
#if 1
const std::string voice_data = audio_data;
auto tmp = common_tokenize(vocab, voice_data, false, true); auto tmp = common_tokenize(vocab, voice_data, false, true);
printf("\n\n"); printf("\n\n");
for (int i = 0; i < tmp.size(); ++i) { for (size_t i = 0; i < tmp.size(); ++i) {
printf("%d, ", tmp[i]); printf("%d, ", tmp[i]);
} }
printf("\n\n"); printf("\n\n");
prompt_add(prompt_inp, tmp);
#else #else
prompt_add(prompt_inp, llama_tokens { prompt_add(prompt_inp, llama_tokens {
151667, 198, 1782, 155780, 151669, 151929, 152412, 152308, 152585, 151667, 198, 1782, 155780, 151669, 151929, 152412, 152308, 152585,
@ -693,6 +801,7 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
152109, 152255, 151739, 152267, 152759, 153318, 153165, 153349, 152109, 152255, 151739, 152267, 152759, 153318, 153165, 153349,
151670,}); 151670,});
#endif #endif
}
// print the prompt token-by-token // print the prompt token-by-token