Merge branch 'master' into gg/llama-kv-cache
ggml-ci
@@ -24,7 +24,7 @@
 #include <string>
 #include <vector>
 
-#include "chat-template.hpp"
+#include "chat.h"
 #include "common.h"
 #include "json.hpp"
 #include "linenoise.cpp/linenoise.h"
@@ -113,6 +113,7 @@ class Opt {
     llama_context_params ctx_params;
     llama_model_params model_params;
     std::string model_;
+    std::string chat_template_file;
     std::string user;
     bool use_jinja = false;
     int context_size = -1, ngl = -1;
@@ -148,6 +149,16 @@ class Opt {
         return 0;
     }
 
+    int handle_option_with_value(int argc, const char ** argv, int & i, std::string & option_value) {
+        if (i + 1 >= argc) {
+            return 1;
+        }
+
+        option_value = argv[++i];
+
+        return 0;
+    }
+
     int parse(int argc, const char ** argv) {
         bool options_parsing = true;
         for (int i = 1, positional_args_i = 0; i < argc; ++i) {
@@ -169,6 +180,11 @@ class Opt {
                 verbose = true;
             } else if (options_parsing && strcmp(argv[i], "--jinja") == 0) {
                 use_jinja = true;
+            } else if (options_parsing && strcmp(argv[i], "--chat-template-file") == 0){
+                if (handle_option_with_value(argc, argv, i, chat_template_file) == 1) {
+                    return 1;
+                }
+                use_jinja = true;
             } else if (options_parsing && parse_flag(argv, i, "-h", "--help")) {
                 help = true;
                 return 0;
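For illustration, the option-with-value pattern introduced above (consume the next argv entry as the option's value, fail if it is missing) can be exercised in a standalone program roughly as follows; the option name reused here comes from the diff, everything else is made up for the example:

    #include <cstdio>
    #include <cstring>
    #include <string>

    // Standalone sketch of the look-ahead parsing used by handle_option_with_value().
    static int handle_option_with_value(int argc, const char ** argv, int & i, std::string & option_value) {
        if (i + 1 >= argc) {
            return 1; // option given without a value
        }
        option_value = argv[++i];
        return 0;
    }

    int main(int argc, const char ** argv) {
        std::string chat_template_file;
        for (int i = 1; i < argc; ++i) {
            if (strcmp(argv[i], "--chat-template-file") == 0) {
                if (handle_option_with_value(argc, argv, i, chat_template_file) == 1) {
                    fprintf(stderr, "missing value for --chat-template-file\n");
                    return 1;
                }
            }
        }
        printf("chat template file: '%s'\n", chat_template_file.c_str());
        return 0;
    }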
@@ -207,6 +223,11 @@ class Opt {
             "Options:\n"
             "  -c, --context-size <value>\n"
             "      Context size (default: %d)\n"
+            "  --chat-template-file <path>\n"
+            "      Path to the file containing the chat template to use with the model.\n"
+            "      Only supports jinja templates and implicitly sets the --jinja flag.\n"
+            "  --jinja\n"
+            "      Use jinja templating for the chat template of the model\n"
             "  -n, -ngl, --ngl <value>\n"
             "      Number of GPU layers (default: %d)\n"
             "  --temp <value>\n"
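For reference, a hypothetical invocation combining the new options could look like the following (model path and template file name are placeholders; per the help text, --chat-template-file implies --jinja):

    llama-run --jinja --chat-template-file ./template.jinja /path/to/model.gguf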
@@ -261,13 +282,12 @@ static int get_terminal_width() {
 #endif
 }
 
-#ifdef LLAMA_USE_CURL
 class File {
   public:
     FILE * file = nullptr;
 
     FILE * open(const std::string & filename, const char * mode) {
-        file = fopen(filename.c_str(), mode);
+        file = ggml_fopen(filename.c_str(), mode);
 
         return file;
     }
@@ -303,6 +323,28 @@ class File {
         return 0;
     }
 
+    std::string read_all(const std::string & filename){
+        open(filename, "r");
+        lock();
+        if (!file) {
+            printe("Error opening file '%s': %s", filename.c_str(), strerror(errno));
+            return "";
+        }
+
+        fseek(file, 0, SEEK_END);
+        size_t size = ftell(file);
+        fseek(file, 0, SEEK_SET);
+
+        std::string out;
+        out.resize(size);
+        size_t read_size = fread(&out[0], 1, size, file);
+        if (read_size != size) {
+            printe("Error reading file '%s': %s", filename.c_str(), strerror(errno));
+            return "";
+        }
+        return out;
+    }
+
     ~File() {
         if (fd >= 0) {
 # ifdef _WIN32
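The read_all() helper added above uses the usual seek-to-end/ftell/rewind/fread pattern. A self-contained sketch of the same idea, without the File wrapper or its locking, might look like this (the function name is illustrative):

    #include <cerrno>
    #include <cstdio>
    #include <cstring>
    #include <string>

    // Read an entire file into a string, returning "" on error (sketch only).
    static std::string read_whole_file(const std::string & filename) {
        FILE * f = fopen(filename.c_str(), "rb");
        if (!f) {
            fprintf(stderr, "Error opening file '%s': %s\n", filename.c_str(), strerror(errno));
            return "";
        }
        fseek(f, 0, SEEK_END);
        long size = ftell(f);
        fseek(f, 0, SEEK_SET);
        std::string out(size > 0 ? (size_t) size : 0, '\0');
        size_t read_size = fread(&out[0], 1, out.size(), f);
        fclose(f);
        if (read_size != out.size()) {
            fprintf(stderr, "Error reading file '%s': %s\n", filename.c_str(), strerror(errno));
            return "";
        }
        return out;
    }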
@@ -327,6 +369,7 @@ class File {
 # endif
 };
 
+#ifdef LLAMA_USE_CURL
 class HttpClient {
   public:
     int init(const std::string & url, const std::vector<std::string> & headers, const std::string & output_file,
@@ -557,7 +600,7 @@ class LlamaData {
     llama_model_ptr model;
     llama_sampler_ptr sampler;
    llama_context_ptr context;
-    std::vector<llama_chat_message> messages;
+    std::vector<llama_chat_message> messages; // TODO: switch to common_chat_msg
     std::list<std::string> msg_strs;
     std::vector<char> fmtted;
 
@@ -834,44 +877,23 @@ static void add_message(const char * role, const std::string & text, LlamaData &
 }
 
 // Function to apply the chat template and resize `formatted` if needed
-static int apply_chat_template(const common_chat_template & tmpl, LlamaData & llama_data, const bool append, bool use_jinja) {
-    if (use_jinja) {
-        json messages = json::array();
-        for (const auto & msg : llama_data.messages) {
-            messages.push_back({
-                {"role", msg.role},
-                {"content", msg.content},
-            });
-        }
-        try {
-            minja::chat_template_inputs tmpl_inputs;
-            tmpl_inputs.messages = messages;
-            tmpl_inputs.add_generation_prompt = append;
-
-            minja::chat_template_options tmpl_opts;
-            tmpl_opts.use_bos_token = false;
-            tmpl_opts.use_eos_token = false;
-
-            auto result = tmpl.apply(tmpl_inputs, tmpl_opts);
-            llama_data.fmtted.resize(result.size() + 1);
-            memcpy(llama_data.fmtted.data(), result.c_str(), result.size() + 1);
-            return result.size();
-        } catch (const std::exception & e) {
-            printe("failed to render the chat template: %s\n", e.what());
-            return -1;
-        }
-    }
-    int result = llama_chat_apply_template(
-        tmpl.source().c_str(), llama_data.messages.data(), llama_data.messages.size(), append,
-        append ? llama_data.fmtted.data() : nullptr, append ? llama_data.fmtted.size() : 0);
-    if (append && result > static_cast<int>(llama_data.fmtted.size())) {
-        llama_data.fmtted.resize(result);
-        result = llama_chat_apply_template(tmpl.source().c_str(), llama_data.messages.data(),
-                                           llama_data.messages.size(), append, llama_data.fmtted.data(),
-                                           llama_data.fmtted.size());
-    }
-
-    return result;
+static int apply_chat_template(const struct common_chat_templates * tmpls, LlamaData & llama_data, const bool append, bool use_jinja) {
+    common_chat_templates_inputs inputs;
+    for (const auto & msg : llama_data.messages) {
+        common_chat_msg cmsg;
+        cmsg.role    = msg.role;
+        cmsg.content = msg.content;
+        inputs.messages.push_back(cmsg);
+    }
+    inputs.add_generation_prompt = append;
+    inputs.use_jinja = use_jinja;
+
+    auto chat_params = common_chat_templates_apply(tmpls, inputs);
+    // TODO: use other params for tool calls.
+    auto result = chat_params.prompt;
+    llama_data.fmtted.resize(result.size() + 1);
+    memcpy(llama_data.fmtted.data(), result.c_str(), result.size() + 1);
+    return result.size();
 }
 
 // Function to tokenize the prompt
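To make the new flow explicit: messages are copied into a common_chat_templates_inputs, common_chat_templates_apply() renders them, and only the prompt field of the result is used. A rough sketch of that path, assuming llama.cpp's common "chat.h" header and the types and signatures exactly as they appear in the diff (not re-verified here, so treat it as illustrative), would be:

    #include <string>
    #include <utility>
    #include <vector>

    #include "chat.h" // llama.cpp common library, assumed available in-tree

    // Illustrative helper (hypothetical name): render role/content pairs to a prompt string.
    static std::string render_prompt_sketch(const common_chat_templates * tmpls,
                                            const std::vector<std::pair<std::string, std::string>> & msgs,
                                            bool add_generation_prompt, bool use_jinja) {
        common_chat_templates_inputs inputs;
        for (const auto & m : msgs) {
            common_chat_msg cmsg;
            cmsg.role    = m.first;
            cmsg.content = m.second;
            inputs.messages.push_back(cmsg);
        }
        inputs.add_generation_prompt = add_generation_prompt;
        inputs.use_jinja             = use_jinja;

        // Other fields of the returned params relate to tool calls and are ignored here,
        // matching the TODO in the diff.
        auto chat_params = common_chat_templates_apply(tmpls, inputs);
        return chat_params.prompt;
    }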
@@ -1015,8 +1037,8 @@ static int generate_response(LlamaData & llama_data, const std::string & prompt,
 }
 
 // Helper function to apply the chat template and handle errors
-static int apply_chat_template_with_error_handling(const common_chat_template & tmpl, LlamaData & llama_data, const bool append, int & output_length, bool use_jinja) {
-    const int new_len = apply_chat_template(tmpl, llama_data, append, use_jinja);
+static int apply_chat_template_with_error_handling(const common_chat_templates * tmpls, LlamaData & llama_data, const bool append, int & output_length, bool use_jinja) {
+    const int new_len = apply_chat_template(tmpls, llama_data, append, use_jinja);
     if (new_len < 0) {
         printe("failed to apply the chat template\n");
         return -1;
@@ -1074,12 +1096,33 @@ static int get_user_input(std::string & user_input, const std::string & user) {
     return 0;
 }
 
+// Reads a chat template file to be used
+static std::string read_chat_template_file(const std::string & chat_template_file) {
+    if(chat_template_file.empty()){
+        return "";
+    }
+
+    File file;
+    std::string chat_template = "";
+    chat_template = file.read_all(chat_template_file);
+    if(chat_template.empty()){
+        printe("Error opening chat template file '%s': %s", chat_template_file.c_str(), strerror(errno));
+        return "";
+    }
+    return chat_template;
+}
+
 // Main chat loop function
-static int chat_loop(LlamaData & llama_data, const std::string & user, bool use_jinja) {
+static int chat_loop(LlamaData & llama_data, const std::string & user, const std::string & chat_template_file, bool use_jinja) {
     int prev_len = 0;
     llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get()));
-    auto chat_templates = common_chat_templates_from_model(llama_data.model.get(), "");
-    GGML_ASSERT(chat_templates.template_default);
+
+    std::string chat_template = "";
+    if(!chat_template_file.empty()){
+        chat_template = read_chat_template_file(chat_template_file);
+    }
+    auto chat_templates = common_chat_templates_init(llama_data.model.get(), chat_template.empty() ? nullptr : chat_template);
+
     static const bool stdout_a_terminal = is_stdout_a_terminal();
     while (true) {
         // Get user input
@@ -1090,7 +1133,7 @@ static int chat_loop(LlamaData & llama_data, const std::string & user, bool use_
 
         add_message("user", user.empty() ? user_input : user, llama_data);
         int new_len;
-        if (apply_chat_template_with_error_handling(*chat_templates.template_default, llama_data, true, new_len, use_jinja) < 0) {
+        if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, true, new_len, use_jinja) < 0) {
             return 1;
         }
 
@@ -1105,7 +1148,7 @@ static int chat_loop(LlamaData & llama_data, const std::string & user, bool use_
         }
 
         add_message("assistant", response, llama_data);
-        if (apply_chat_template_with_error_handling(*chat_templates.template_default, llama_data, false, prev_len, use_jinja) < 0) {
+        if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, false, prev_len, use_jinja) < 0) {
             return 1;
         }
     }
@@ -1165,7 +1208,7 @@ int main(int argc, const char ** argv) {
         return 1;
     }
 
-    if (chat_loop(llama_data, opt.user, opt.use_jinja)) {
+    if (chat_loop(llama_data, opt.user, opt.chat_template_file, opt.use_jinja)) {
         return 1;
     }
 