mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-06-28 20:25:20 +00:00
tts: add speaker file support (#12048)
* tts: add speaker file support

Signed-off-by: dm4 <sunrisedm4@gmail.com>

* tts: handle outetts-0.3
* tts: add new line in error message

---------

Signed-off-by: dm4 <sunrisedm4@gmail.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
@ -2452,6 +2452,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||||||
params.vocoder.use_guide_tokens = true;
|
params.vocoder.use_guide_tokens = true;
|
||||||
}
|
}
|
||||||
).set_examples({LLAMA_EXAMPLE_TTS, LLAMA_EXAMPLE_SERVER}));
|
).set_examples({LLAMA_EXAMPLE_TTS, LLAMA_EXAMPLE_SERVER}));
|
||||||
|
add_opt(common_arg(
|
||||||
|
{"--tts-speaker-file"}, "FNAME",
|
||||||
|
"speaker file path for audio generation",
|
||||||
|
[](common_params & params, const std::string & value) {
|
||||||
|
params.vocoder.speaker_file = value;
|
||||||
|
}
|
||||||
|
).set_examples({LLAMA_EXAMPLE_TTS}));
|
||||||
|
|
||||||
// model-specific
|
// model-specific
|
||||||
add_opt(common_arg(
|
add_opt(common_arg(
|
||||||
|
@ -200,6 +200,8 @@ struct common_params_vocoder {
|
|||||||
std::string model = ""; // model path // NOLINT
|
std::string model = ""; // model path // NOLINT
|
||||||
std::string model_url = ""; // model url to download // NOLINT
|
std::string model_url = ""; // model url to download // NOLINT
|
||||||
|
|
||||||
|
std::string speaker_file = ""; // speaker file path // NOLINT
|
||||||
|
|
||||||
bool use_guide_tokens = false; // enable guide tokens to improve TTS accuracy // NOLINT
|
bool use_guide_tokens = false; // enable guide tokens to improve TTS accuracy // NOLINT
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -3,6 +3,7 @@
|
|||||||
#include "sampling.h"
|
#include "sampling.h"
|
||||||
#include "log.h"
|
#include "log.h"
|
||||||
#include "llama.h"
|
#include "llama.h"
|
||||||
|
#include "json.hpp"
|
||||||
|
|
||||||
#define _USE_MATH_DEFINES // For M_PI on MSVC
|
#define _USE_MATH_DEFINES // For M_PI on MSVC
|
||||||
|
|
||||||
@ -16,6 +17,13 @@
|
|||||||
#include <thread>
|
#include <thread>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
using json = nlohmann::ordered_json;
|
||||||
|
|
||||||
|
// OuteTTS versions distinguished by this file; the version selects the prompt
// markup (e.g. "<|text_sep|>" for 0.2 vs "<|space|>" for 0.3 as word separator).
enum outetts_version {
|
||||||
|
OUTETTS_V0_2, // speaker file version "0.2"; also the default when nothing else matches
|
||||||
|
OUTETTS_V0_3, // speaker file version "0.3", or a model chat template of "outetts-0.3"
|
||||||
|
};
|
||||||
|
|
||||||
//
|
//
|
||||||
// Terminal utils
|
// Terminal utils
|
||||||
//
|
//
|
||||||
@ -371,7 +379,7 @@ static std::string replace_numbers_with_words(const std::string & input_text) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Based on: https://github.com/edwko/OuteTTS/blob/a613e79c489d8256dd657ea9168d78de75895d82/outetts/version/v1/prompt_processor.py#L39
|
// Based on: https://github.com/edwko/OuteTTS/blob/a613e79c489d8256dd657ea9168d78de75895d82/outetts/version/v1/prompt_processor.py#L39
|
||||||
static std::string process_text(const std::string & text) {
|
static std::string process_text(const std::string & text, const outetts_version tts_version = OUTETTS_V0_2) {
|
||||||
|
|
||||||
// For now I skipped text romanization as I am unsure how to handle
|
// For now I skipped text romanization as I am unsure how to handle
|
||||||
// uroman and MeCab implementations in C++
|
// uroman and MeCab implementations in C++
|
||||||
@ -401,7 +409,8 @@ static std::string process_text(const std::string & text) {
|
|||||||
if (c == ' ') {
|
if (c == ' ') {
|
||||||
prompt_clean += "<|text_sep|>";
|
prompt_clean += "<|text_sep|>";
|
||||||
*/
|
*/
|
||||||
processed_text = std::regex_replace(processed_text, std::regex(R"(\s)"), "<|text_sep|>");
|
std::string separator = (tts_version == OUTETTS_V0_3) ? "<|space|>" : "<|text_sep|>";
|
||||||
|
processed_text = std::regex_replace(processed_text, std::regex(R"(\s)"), separator);
|
||||||
|
|
||||||
return processed_text;
|
return processed_text;
|
||||||
}
|
}
|
||||||
@ -425,8 +434,8 @@ static void prompt_init(llama_tokens & prompt, const llama_vocab * vocab) {
|
|||||||
prompt_add(prompt, vocab, "<|im_start|>\n", true, true);
|
prompt_add(prompt, vocab, "<|im_start|>\n", true, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::vector<llama_token> prepare_guide_tokens(const llama_vocab * vocab, const std::string & str) {
|
static std::vector<llama_token> prepare_guide_tokens(const llama_vocab * vocab, const std::string & str, const outetts_version tts_version = OUTETTS_V0_2) {
|
||||||
const std::string& delimiter = "<|text_sep|>";
|
const std::string& delimiter = (tts_version == OUTETTS_V0_3 ? "<|space|>" : "<|text_sep|>");
|
||||||
|
|
||||||
std::vector<llama_token> result;
|
std::vector<llama_token> result;
|
||||||
size_t start = 0;
|
size_t start = 0;
|
||||||
@ -452,6 +461,78 @@ static std::vector<llama_token> prepare_guide_tokens(const llama_vocab * vocab,
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Load a speaker profile from a JSON file.
//
// Returns the parsed JSON document. If the file cannot be opened or does not
// contain valid JSON, the failure is logged and a null json value is returned
// (for which empty() is true), so callers can detect both cases the same way.
static json speaker_from_file(const std::string & speaker_file) {
    std::ifstream file(speaker_file);
    if (!file) {
        LOG_ERR("%s: Failed to open file '%s' for reading\n", __func__, speaker_file.c_str());
        return json();
    }

    try {
        return json::parse(file);
    } catch (const json::exception & e) {
        // json::parse() throws on malformed input; report it like the open
        // failure above instead of letting the exception escape to the caller
        LOG_ERR("%s: Failed to parse speaker file '%s': %s\n", __func__, speaker_file.c_str(), e.what());
        return json();
    }
}
|
||||||
|
|
||||||
|
// Determine which OuteTTS version to target.
//
// Precedence:
//   1. the "version" field of the speaker profile, when it is "0.2" or "0.3"
//   2. the model's chat template, when it equals "outetts-0.3"
//   3. OUTETTS_V0_2 as the default
//
// An unsupported or non-string "version" value is logged and then ignored, so
// detection falls through to the model-based check.
static outetts_version get_tts_version(llama_model *model, const json & speaker = json::object()) {
    // find() returns end() both when the key is absent and when `speaker` is not an object
    const auto version_it = speaker.find("version");
    if (version_it != speaker.end()) {
        if (version_it->is_string()) {
            const std::string version = version_it->get<std::string>();
            if (version == "0.2") {
                return OUTETTS_V0_2;
            }
            if (version == "0.3") {
                return OUTETTS_V0_3;
            }
            LOG_ERR("%s: Unsupported speaker version '%s'\n", __func__, version.c_str());
        } else {
            // get<std::string>() would throw on a non-string value (e.g. the number 0.3);
            // report it as unsupported instead
            LOG_ERR("%s: Unsupported speaker version '%s'\n", __func__, version_it->dump().c_str());
        }
    }

    // Also could get version from model itself
    const char *chat_template = llama_model_chat_template(model, nullptr);
    if (chat_template && std::string(chat_template) == "outetts-0.3") {
        return OUTETTS_V0_3;
    }

    // Use 0.2 as the default version
    return OUTETTS_V0_2;
}
|
||||||
|
|
||||||
|
static std::string audio_text_from_speaker(json speaker, const outetts_version tts_version = OUTETTS_V0_2) {
|
||||||
|
std::string audio_text = "<|text_start|>";
|
||||||
|
|
||||||
|
if (tts_version == OUTETTS_V0_2 || tts_version == OUTETTS_V0_3) {
|
||||||
|
std::string separator = (tts_version == OUTETTS_V0_3) ? "<|space|>" : "<|text_sep|>";
|
||||||
|
for (const auto &word : speaker["words"]) {
|
||||||
|
audio_text += word["word"].get<std::string>() + separator;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return audio_text;
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::string audio_data_from_speaker(json speaker, const outetts_version tts_version = OUTETTS_V0_2) {
|
||||||
|
std::string audio_data = "<|audio_start|>\n";
|
||||||
|
|
||||||
|
if (tts_version == OUTETTS_V0_2 || tts_version == OUTETTS_V0_3) {
|
||||||
|
std::string code_start = (tts_version == OUTETTS_V0_3) ? "" : "<|code_start|>";
|
||||||
|
std::string code_end = (tts_version == OUTETTS_V0_3) ? "<|space|>" : "<|code_end|>";
|
||||||
|
for (const auto &word : speaker["words"]) {
|
||||||
|
std::string word_text = word["word"].get<std::string>();
|
||||||
|
double duration = word["duration"].get<double>();
|
||||||
|
std::vector<int> codes = word["codes"].get<std::vector<int>>();
|
||||||
|
|
||||||
|
// Create the audio output entry
|
||||||
|
std::ostringstream word_entry;
|
||||||
|
word_entry << word_text << "<|t_" << std::fixed << std::setprecision(2)
|
||||||
|
<< duration << "|>" + code_start;
|
||||||
|
for (const auto &Code : codes) {
|
||||||
|
word_entry << "<|" << Code << "|>";
|
||||||
|
}
|
||||||
|
word_entry << code_end << "\n";
|
||||||
|
audio_data += word_entry.str();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return audio_data;
|
||||||
|
}
|
||||||
|
|
||||||
int main(int argc, char ** argv) {
|
int main(int argc, char ** argv) {
|
||||||
common_params params;
|
common_params params;
|
||||||
|
|
||||||
@ -523,34 +604,9 @@ int main(int argc, char ** argv) {
|
|||||||
std::vector<llama_token> codes;
|
std::vector<llama_token> codes;
|
||||||
std::vector<llama_token> guide_tokens;
|
std::vector<llama_token> guide_tokens;
|
||||||
|
|
||||||
// process prompt and generate voice codes
|
// the default speaker profile is from: https://github.com/edwko/OuteTTS/blob/main/outetts/version/v1/default_speakers/en_male_1.json
|
||||||
{
|
std::string audio_text = "<|text_start|>the<|text_sep|>overall<|text_sep|>package<|text_sep|>from<|text_sep|>just<|text_sep|>two<|text_sep|>people<|text_sep|>is<|text_sep|>pretty<|text_sep|>remarkable<|text_sep|>sure<|text_sep|>i<|text_sep|>have<|text_sep|>some<|text_sep|>critiques<|text_sep|>about<|text_sep|>some<|text_sep|>of<|text_sep|>the<|text_sep|>gameplay<|text_sep|>aspects<|text_sep|>but<|text_sep|>its<|text_sep|>still<|text_sep|>really<|text_sep|>enjoyable<|text_sep|>and<|text_sep|>it<|text_sep|>looks<|text_sep|>lovely<|text_sep|>";
|
||||||
LOG_INF("%s: constructing prompt ..\n", __func__);
|
std::string audio_data = R"(<|audio_start|>
|
||||||
|
|
||||||
std::vector<llama_token> prompt_inp;
|
|
||||||
|
|
||||||
prompt_init(prompt_inp, vocab);
|
|
||||||
|
|
||||||
prompt_add(prompt_inp, vocab, "<|text_start|>the<|text_sep|>overall<|text_sep|>package<|text_sep|>from<|text_sep|>just<|text_sep|>two<|text_sep|>people<|text_sep|>is<|text_sep|>pretty<|text_sep|>remarkable<|text_sep|>sure<|text_sep|>i<|text_sep|>have<|text_sep|>some<|text_sep|>critiques<|text_sep|>about<|text_sep|>some<|text_sep|>of<|text_sep|>the<|text_sep|>gameplay<|text_sep|>aspects<|text_sep|>but<|text_sep|>its<|text_sep|>still<|text_sep|>really<|text_sep|>enjoyable<|text_sep|>and<|text_sep|>it<|text_sep|>looks<|text_sep|>lovely<|text_sep|>", false, true);
|
|
||||||
|
|
||||||
// convert the input text into the necessary format expected by OuteTTS
|
|
||||||
{
|
|
||||||
std::string prompt_clean = process_text(params.prompt);
|
|
||||||
if (params.vocoder.use_guide_tokens) {
|
|
||||||
guide_tokens = prepare_guide_tokens(vocab, prompt_clean);
|
|
||||||
}
|
|
||||||
|
|
||||||
LOG_INF("%s: prompt: '%s'\n", __func__, prompt_clean.c_str());
|
|
||||||
|
|
||||||
prompt_add(prompt_inp, vocab, prompt_clean, false, true);
|
|
||||||
}
|
|
||||||
|
|
||||||
prompt_add(prompt_inp, vocab, "<|text_end|>\n", false, true);
|
|
||||||
|
|
||||||
// disabled to save time on tokenizing each time
|
|
||||||
// TODO: load voices from the json files
|
|
||||||
#if 0
|
|
||||||
const std::string voice_data = R"(<|audio_start|>
|
|
||||||
the<|t_0.08|><|code_start|><|257|><|740|><|636|><|913|><|788|><|1703|><|code_end|>
|
the<|t_0.08|><|code_start|><|257|><|740|><|636|><|913|><|788|><|1703|><|code_end|>
|
||||||
overall<|t_0.36|><|code_start|><|127|><|201|><|191|><|774|><|700|><|532|><|1056|><|557|><|798|><|298|><|1741|><|747|><|1662|><|1617|><|1702|><|1527|><|368|><|1588|><|1049|><|1008|><|1625|><|747|><|1576|><|728|><|1019|><|1696|><|1765|><|code_end|>
|
overall<|t_0.36|><|code_start|><|127|><|201|><|191|><|774|><|700|><|532|><|1056|><|557|><|798|><|298|><|1741|><|747|><|1662|><|1617|><|1702|><|1527|><|368|><|1588|><|1049|><|1008|><|1625|><|747|><|1576|><|728|><|1019|><|1696|><|1765|><|code_end|>
|
||||||
package<|t_0.56|><|code_start|><|935|><|584|><|1319|><|627|><|1016|><|1491|><|1344|><|1117|><|1526|><|1040|><|239|><|1435|><|951|><|498|><|723|><|1180|><|535|><|789|><|1649|><|1637|><|78|><|465|><|1668|><|901|><|595|><|1675|><|117|><|1009|><|1667|><|320|><|840|><|79|><|507|><|1762|><|1508|><|1228|><|1768|><|802|><|1450|><|1457|><|232|><|639|><|code_end|>
|
package<|t_0.56|><|code_start|><|935|><|584|><|1319|><|627|><|1016|><|1491|><|1344|><|1117|><|1526|><|1040|><|239|><|1435|><|951|><|498|><|723|><|1180|><|535|><|789|><|1649|><|1637|><|78|><|465|><|1668|><|901|><|595|><|1675|><|117|><|1009|><|1667|><|320|><|840|><|79|><|507|><|1762|><|1508|><|1228|><|1768|><|802|><|1450|><|1457|><|232|><|639|><|code_end|>
|
||||||
@ -582,12 +638,64 @@ it<|t_0.09|><|code_start|><|848|><|1366|><|395|><|1601|><|1513|><|593|><|1302|><
|
|||||||
looks<|t_0.27|><|code_start|><|1281|><|1266|><|1755|><|572|><|248|><|1751|><|1257|><|695|><|1380|><|457|><|659|><|585|><|1315|><|1105|><|1776|><|736|><|24|><|736|><|654|><|1027|><|code_end|>
|
looks<|t_0.27|><|code_start|><|1281|><|1266|><|1755|><|572|><|248|><|1751|><|1257|><|695|><|1380|><|457|><|659|><|585|><|1315|><|1105|><|1776|><|736|><|24|><|736|><|654|><|1027|><|code_end|>
|
||||||
lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|1481|><|1721|><|1123|><|438|><|1246|><|1251|><|795|><|659|><|1381|><|1658|><|217|><|1772|><|562|><|952|><|107|><|1129|><|1112|><|467|><|550|><|1079|><|840|><|1615|><|1469|><|1380|><|168|><|917|><|836|><|1827|><|437|><|583|><|67|><|595|><|1087|><|1646|><|1493|><|1677|><|code_end|>)";
|
lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|1481|><|1721|><|1123|><|438|><|1246|><|1251|><|795|><|659|><|1381|><|1658|><|217|><|1772|><|562|><|952|><|107|><|1129|><|1112|><|467|><|550|><|1079|><|840|><|1615|><|1469|><|1380|><|168|><|917|><|836|><|1827|><|437|><|583|><|67|><|595|><|1087|><|1646|><|1493|><|1677|><|code_end|>)";
|
||||||
|
|
||||||
|
// audio data for 0.3 version
|
||||||
|
outetts_version tts_version = get_tts_version(model_ttc);
|
||||||
|
if (tts_version == OUTETTS_V0_3) {
|
||||||
|
audio_text = std::regex_replace(audio_text, std::regex(R"(<\|text_sep\|>)"), "<|space|>");
|
||||||
|
audio_data = std::regex_replace(audio_data, std::regex(R"(<\|code_start\|>)"), "");
|
||||||
|
audio_data = std::regex_replace(audio_data, std::regex(R"(<\|code_end\|>)"), "<|space|>");
|
||||||
|
}
|
||||||
|
|
||||||
|
// load speaker if given
|
||||||
|
if (!params.vocoder.speaker_file.empty()) {
|
||||||
|
LOG_INF("%s: loading speaker ..\n", __func__);
|
||||||
|
json speaker = speaker_from_file(params.vocoder.speaker_file);
|
||||||
|
if (speaker.empty()) {
|
||||||
|
LOG_ERR("%s: Failed to load speaker file '%s'\n", __func__, params.vocoder.speaker_file.c_str());
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
audio_text = audio_text_from_speaker(speaker, tts_version);
|
||||||
|
audio_data = audio_data_from_speaker(speaker, tts_version);
|
||||||
|
}
|
||||||
|
|
||||||
|
// process prompt and generate voice codes
|
||||||
|
{
|
||||||
|
LOG_INF("%s: constructing prompt ..\n", __func__);
|
||||||
|
|
||||||
|
std::vector<llama_token> prompt_inp;
|
||||||
|
|
||||||
|
prompt_init(prompt_inp, vocab);
|
||||||
|
|
||||||
|
prompt_add(prompt_inp, vocab, audio_text, false, true);
|
||||||
|
|
||||||
|
// convert the input text into the necessary format expected by OuteTTS
|
||||||
|
{
|
||||||
|
std::string prompt_clean = process_text(params.prompt, tts_version);
|
||||||
|
if (params.vocoder.use_guide_tokens) {
|
||||||
|
guide_tokens = prepare_guide_tokens(vocab, prompt_clean, tts_version);
|
||||||
|
}
|
||||||
|
|
||||||
|
LOG_INF("%s: prompt: '%s'\n", __func__, prompt_clean.c_str());
|
||||||
|
|
||||||
|
prompt_add(prompt_inp, vocab, prompt_clean, false, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
prompt_add(prompt_inp, vocab, "<|text_end|>\n", false, true);
|
||||||
|
|
||||||
|
if (!params.vocoder.speaker_file.empty()) {
|
||||||
|
prompt_add(prompt_inp, vocab, audio_data, false, true);
|
||||||
|
} else {
|
||||||
|
// disabled to save time on tokenizing each time
|
||||||
|
#if 1
|
||||||
|
const std::string voice_data = audio_data;
|
||||||
|
|
||||||
auto tmp = common_tokenize(vocab, voice_data, false, true);
|
auto tmp = common_tokenize(vocab, voice_data, false, true);
|
||||||
printf("\n\n");
|
printf("\n\n");
|
||||||
for (int i = 0; i < tmp.size(); ++i) {
|
for (size_t i = 0; i < tmp.size(); ++i) {
|
||||||
printf("%d, ", tmp[i]);
|
printf("%d, ", tmp[i]);
|
||||||
}
|
}
|
||||||
printf("\n\n");
|
printf("\n\n");
|
||||||
|
prompt_add(prompt_inp, tmp);
|
||||||
#else
|
#else
|
||||||
prompt_add(prompt_inp, llama_tokens {
|
prompt_add(prompt_inp, llama_tokens {
|
||||||
151667, 198, 1782, 155780, 151669, 151929, 152412, 152308, 152585,
|
151667, 198, 1782, 155780, 151669, 151929, 152412, 152308, 152585,
|
||||||
@ -693,6 +801,7 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
|
|||||||
152109, 152255, 151739, 152267, 152759, 153318, 153165, 153349,
|
152109, 152255, 151739, 152267, 152759, 153318, 153165, 153349,
|
||||||
151670,});
|
151670,});
|
||||||
#endif
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
// print the prompt token-by-token
|
// print the prompt token-by-token
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user