mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-08-16 05:02:58 -04:00)
finetune: SGD optimizer, more CLI args (#13873)
* examples/finetune -opt SGD (stochastic gradient descent) memory opt

  Add a unit-tested GGML_OPT_OPTIMIZER_SGD to ggml; it avoids allocating the m and v
  moment tensors. finetune.cpp supports the arg -opt SGD (or sgd); the default remains
  AdamW as before.

  llama 3.2-1b-F32 result (wikipedia 100-line finetune): observed 11 GB of GPU RAM
  (41 sec/epoch) when using SGD instead of 19 GB (55 sec/epoch) using AdamW.

  With the same GPU memory, AdamW can only reach 512 batch/context before OOM:

    train: [███████▉] data=0000140/0000140 loss=0.02575±0.00099 acc=99.52±0.03% t=00:00:47 ETA=00:00:00
    val:   [███████▉] data=0000008/0000008 loss=4.76565±0.28810 acc=41.46±0.77% t=00:00:00 ETA=00:00:00

  SGD is superior here, though it converges more slowly, with a maximum of 1728
  batch/context before OOM (note especially the better validation performance):

    train: [███████▉] data=0000039/0000039 loss=0.00371±0.00010 acc=99.96±0.01% t=00:00:41 ETA=00:00:00
    val:   [███████▉] data=0000003/0000003 loss=5.11406±0.76034 acc=48.01±0.69% t=00:00:01 ETA=00:00:00

  Note: when finetuning long enough (or with a large enough -lr), validation accuracy
  *eventually* drops ("catastrophic forgetting").

  The -lr-half (half-life) option is useful for SGD to avoid oscillation or very slow
  underdamped learning; it makes setting -lr more forgiving. The terminal -lr is for
  now set via -lr-halvings, i.e. if you want at most 1/8 of the initial -lr you set
  -lr-halvings 3.

  Note: the objective loss may not be directly comparable between AdamW and SGD;
  check perplexity or accuracy, or consider relative improvements, when judging
  convergence.

  New finetune args: -wd 1e-9 enables weight decay in SGD or AdamW, and -epochs N
  sets the maximum number of epochs (default 2 as before).

  Caching (1 - wd*alpha) in the 'adamw' opt struct gave no noticeable perf benefit
  and is disabled (it is still done for the new SGD, though).

  Since optimizer memory is pre-allocated, ggml_opt_get_optimizer_params could
  probably switch between SGD and AdamW with each epoch, but it would need to use
  AdamW for the first one (unconfirmed; there is no cmdline arg to set such a policy
  yet).

  test-opt checks AdamW as before and now SGD as well (except for a few tests disabled
  for SGD only; they probably just need logged values and alternate reference values);
  the tolerance on the 'regression' test is broader for SGD so we don't need many more
  epochs.

* Vulkan: Implement GGML_OP_OPT_STEP_SGD

* tests: Fix OPT_STEP_SGD test-backend-ops

* SGD op param store weight-decay and not 1-alpha*wd

* minor + cosmetic changes

* fix vulkan sgd

* try CI fix

---------

Co-authored-by: 0cc4m <picard12@live.de>
Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
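For readers skimming the commit message, the following is a rough, illustrative sketch
(assumptions: this is not the ggml kernel or finetune.cpp code; sgd_step and terminal_lr
are hypothetical names used only for illustration) of the parameter update an SGD step
with weight decay performs, and of the -lr-halvings arithmetic described above:

    #include <cmath>
    #include <cstddef>

    // Conceptual SGD step with weight decay: x <- x*(1 - alpha*wd) - alpha*g.
    // No m/v moment tensors are kept per parameter, which is where the GPU-memory
    // savings relative to AdamW come from.
    static void sgd_step(float * x, const float * g, size_t n, float alpha, float wd) {
        for (size_t i = 0; i < n; ++i) {
            x[i] = x[i] * (1.0f - alpha * wd) - alpha * g[i];
        }
    }

    // -lr-halvings N caps the terminal learning rate at the initial -lr / 2^N,
    // e.g. -lr 1e-4 with -lr-halvings 3 bottoms out at 1.25e-5.
    static float terminal_lr(float lr0, int halvings) {
        return lr0 / std::pow(2.0f, (float) halvings);
    }

    int main() {
        float x[2] = { 1.0f, -2.0f };
        const float g[2] = { 0.5f, 0.25f };
        sgd_step(x, g, 2, /*alpha=*/ terminal_lr(1e-4f, 3), /*wd=*/ 1e-9f);
        return 0;
    }

AdamW additionally maintains first- and second-moment tensors per parameter, which is
what the ~11 GB vs ~19 GB GPU RAM comparison in the commit message reflects.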
@@ -1,8 +1,12 @@
// TODO refactor

#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"
#include "ggml-cpu.h"
#include "ggml-opt.h"
#include "../ggml/src/ggml-impl.h"
#include "../common/common.h"

#include <cmath>
#include <cinttypes>
@@ -11,6 +15,8 @@
#include <thread>
#include <vector>

#define TEST_LOG(...) GGML_LOG_DEBUG(__VA_ARGS__)

static bool almost_equal(const double a, const double b, const double atol) {
    return fabs(a - b) < atol;
}
@@ -40,14 +46,20 @@ struct helper_ctx_data {
// These default values make it easier to check optimization results vs. expected values.
static ggml_opt_optimizer_params helper_get_test_opt_pars(void * userdata) {
    ggml_opt_optimizer_params result = ggml_opt_get_default_optimizer_params(userdata);

    result.adamw.alpha = 1.0f;
    result.adamw.beta1 = 0.0f;
    result.adamw.beta2 = 0.0f;
    result.adamw.eps = 0.0f;
    result.adamw.wd = 0.0f;
    result.sgd.wd = 0.0f;
    result.sgd.alpha = 1.0f;

    return result;
}

static helper_ctx_data helper_get_ctx_data(
        enum ggml_opt_optimizer_type optim,
        ggml_backend_sched_t backend_sched,
        ggml_backend_t backend,
        const bool init_opt_ctx = true,
@@ -134,10 +146,13 @@ static helper_ctx_data helper_get_ctx_data(
    opt_params.inputs = inputs;
    opt_params.outputs = outputs;
    opt_params.opt_period = opt_period;
    opt_params.optimizer = optim;
    if (!optimizer_defaults) {
        opt_params.get_opt_pars = helper_get_test_opt_pars;
    }
    GGML_ASSERT(opt_params.get_opt_pars);
    ggml_opt_context_t opt_ctx = init_opt_ctx ? ggml_opt_init(opt_params) : nullptr;
    GGML_ASSERT(!opt_ctx || ggml_opt_context_optimizer_type(opt_ctx) == opt_params.optimizer);

    ggml_opt_result_t result = ggml_opt_result_init();
    ggml_opt_result_t result2 = ggml_opt_result_init();
@@ -158,25 +173,37 @@ static void helper_free_ctx_data(struct helper_ctx_data ctx_data) {
    ggml_opt_dataset_free(ctx_data.dataset_unsupervised);
}

static void print_ok(bool subtest_ok) {
    printf(subtest_ok ? "\033[1;32mOK\033[0m\n" : "\033[1;31mFAIL\033[0m\n");
}

static void helper_after_test(
        enum ggml_opt_optimizer_type optim,
        const char * func, const bool high_level, const std::string options,
        const std::string subtest, const bool subtest_ok, int & ntest, int & npass) {
    printf(" %s(high_level=%s%s, subtest=%s): ",
        func, high_level ? "yes" : "no", options.c_str(), subtest.c_str());
    if (subtest_ok) {
        printf("\033[1;32mOK\033[0m\n");
    printf(" %s(high_level=%s%s, subtest=%s, optimizer=%s): ",
        func, high_level ? "yes" : "no", options.c_str(), subtest.c_str(), ggml_opt_optimizer_name(optim));
    print_ok(subtest_ok);
    if (subtest_ok)
        npass++;
    } else {
        printf("\033[1;31mFAIL\033[0m\n");
    }
    ntest++;
}

static std::pair<int, int> test_dataset(ggml_backend_sched_t backend_sched, ggml_backend_t backend, const bool shuffle) {
static void print_ok(const char * func, bool subtest_ok, int & npass, int & ntest, const char * args = "") {
    printf(" %s(%s): ", func, args);
    print_ok(subtest_ok);
    if (subtest_ok)
        npass++;
    ++ntest;
}

static std::pair<int, int> test_dataset(
    enum ggml_opt_optimizer_type optim,
    ggml_backend_sched_t backend_sched, ggml_backend_t backend, const bool shuffle) {
    int ntest = 0;
    int npass = 0;

    struct helper_ctx_data cd = helper_get_ctx_data(backend_sched, backend);
    struct helper_ctx_data cd = helper_get_ctx_data(optim, backend_sched, backend);

    for (int64_t ndata_shard = 1; ndata_shard <= ndata; ++ndata_shard) {
        ggml_opt_dataset_t dataset = cd.datasets_supervised[ndata_shard-1];
@@ -255,11 +282,13 @@ static std::pair<int, int> test_dataset(ggml_backend_sched_t backend_sched, ggml
    return std::make_pair(npass, ntest);
}

static std::pair<int, int> test_grad(ggml_backend_sched_t backend_sched, ggml_backend_t backend) {
static std::pair<int, int> test_grad(
    enum ggml_opt_optimizer_type optim,
    ggml_backend_sched_t backend_sched, ggml_backend_t backend) {
    int ntest = 0;
    int npass = 0;

    struct helper_ctx_data cd = helper_get_ctx_data(backend_sched, backend, /*init_opt_ctx =*/ true, /*optimizer_defaults =*/ false,
    struct helper_ctx_data cd = helper_get_ctx_data(optim, backend_sched, backend, /*init_opt_ctx =*/ true, /*optimizer_defaults =*/ false,
        /*nbatch_logical =*/ 999999, /*nbatch_physical =*/ 1);

    std::vector<float> grad_history(ndata);
@@ -270,6 +299,7 @@ static std::pair<int, int> test_grad(ggml_backend_sched_t backend_sched, ggml_ba
    for (int idata = 0; idata < ndata; ++idata) {
        const float idataf = idata;
        ggml_opt_alloc(cd.opt_ctx, /*backward =*/ true);
        // leaked
        ggml_backend_tensor_set(cd.inputs, &idataf, 0, ggml_nbytes(cd.inputs));
        ggml_opt_eval(cd.opt_ctx, cd.result);
        ggml_backend_tensor_get(ggml_opt_grad_acc(cd.opt_ctx, cd.weights), grad_history.data() + idata, 0, sizeof(float));
@@ -298,19 +328,21 @@ static std::pair<int, int> test_grad(ggml_backend_sched_t backend_sched, ggml_ba
}

static void helper_after_test_forward_backward(
        enum ggml_opt_optimizer_type optim,
        const char * func, const bool high_level, const bool shuffle,
        const std::string subtest, const bool subtest_ok, int & ntest, int & npass) {
    std::string options = ", shuffle=";
    options += shuffle ? "yes" : "no";
    helper_after_test(func, high_level, options, subtest, subtest_ok, ntest, npass);
    helper_after_test(optim, func, high_level, options, subtest, subtest_ok, ntest, npass);
}

static std::pair<int, int> test_forward_backward(
    enum ggml_opt_optimizer_type optim,
    ggml_backend_sched_t backend_sched, ggml_backend_t backend, const bool high_level, const bool shuffle) {
    int ntest = 0;
    int npass = 0;

    struct helper_ctx_data cd = helper_get_ctx_data(backend_sched, backend, /*init_opt_ctx =*/ true, /*optimizer_defaults =*/ false);
    struct helper_ctx_data cd = helper_get_ctx_data(optim, backend_sched, backend, /*init_opt_ctx =*/ true, /*optimizer_defaults =*/ false);
    struct ggml_tensor * loss = ggml_opt_loss(cd.opt_ctx);

    std::vector<float> loss_history(ndata);
@@ -328,7 +360,7 @@ static std::pair<int, int> test_forward_backward(
        double accuracy_unc;
        ggml_opt_result_accuracy(cd.result, &accuracy, &accuracy_unc);
        const bool subtest_ok = ndata == 0 && loss == 0.0 && std::isnan(loss_unc) && std::isnan(accuracy) && std::isnan(accuracy_unc);
        helper_after_test_forward_backward(__func__, high_level, shuffle, "results_initial", subtest_ok, ntest, npass);
        helper_after_test_forward_backward(optim, __func__, high_level, shuffle, "results_initial", subtest_ok, ntest, npass);
    }

    if (high_level) {
@@ -351,7 +383,7 @@ static std::pair<int, int> test_forward_backward(
        float weights;
        ggml_backend_tensor_get(cd.weights, &weights, 0, sizeof(float));
        const bool subtest_ok = weights == ndata/2;
        helper_after_test_forward_backward(__func__, high_level, shuffle, "weights_after_forward", subtest_ok, ntest, npass);
        helper_after_test_forward_backward(optim, __func__, high_level, shuffle, "weights_after_forward", subtest_ok, ntest, npass);
    }
    {
        int64_t ndata;
@@ -368,13 +400,14 @@ static std::pair<int, int> test_forward_backward(
        ggml_opt_result_accuracy(cd.result, &accuracy, &accuracy_unc);
        subtest_ok = subtest_ok && std::isnan(accuracy) && std::isnan(accuracy_unc);

        helper_after_test_forward_backward(__func__, high_level, shuffle, "results_after_forward", subtest_ok, ntest, npass);
        helper_after_test_forward_backward(optim, __func__, high_level, shuffle, "results_after_forward", subtest_ok, ntest, npass);
    }

    float w0;
    ggml_backend_tensor_get(cd.weights, &w0, 0, sizeof(float));
    for (int i = 0; i < 10; ++i) {
        ggml_opt_alloc(cd.opt_ctx, /*backward =*/ true);
        // leaked.
        ggml_opt_eval(cd.opt_ctx, cd.result);
    }
    ggml_backend_tensor_set(cd.weights, &w0, 0, sizeof(float));
@@ -405,8 +438,9 @@ static std::pair<int, int> test_forward_backward(
    {
        float weights;
        ggml_backend_tensor_get(cd.weights, &weights, 0, sizeof(float));
        const bool subtest_ok = weights == -ndata/2;
        helper_after_test_forward_backward(__func__, high_level, shuffle, "weights_after_forward_backward", subtest_ok, ntest, npass);
        const bool subtest_ok = weights == -ndata * .5;
        TEST_LOG("%s: ndata=%d weights=%f\n", __func__, (int) ndata, (double) weights);
        helper_after_test_forward_backward(optim, __func__, high_level, shuffle, "weights_after_forward_backward", subtest_ok, ntest, npass);
    }
    {
        int64_t ndata;
@@ -423,7 +457,7 @@ static std::pair<int, int> test_forward_backward(
        ggml_opt_result_accuracy(cd.result, &accuracy, &accuracy_unc);
        subtest_ok = subtest_ok && std::isnan(accuracy) && std::isnan(accuracy_unc);

        helper_after_test_forward_backward(__func__, high_level, shuffle, "result_after_forward_backward", subtest_ok, ntest, npass);
        helper_after_test_forward_backward(optim, __func__, high_level, shuffle, "result_after_forward_backward", subtest_ok, ntest, npass);
    }

    helper_free_ctx_data(cd);
@@ -431,7 +465,9 @@ static std::pair<int, int> test_forward_backward(
    return std::make_pair(npass, ntest);
}

static std::pair<int, int> test_epoch_vs_fit(ggml_backend_sched_t backend_sched, ggml_backend_t backend) {
static std::pair<int, int> test_epoch_vs_fit(
    enum ggml_opt_optimizer_type optim,
    ggml_backend_sched_t backend_sched, ggml_backend_t backend) {
    int ntest = 0;
    int npass = 0;

@@ -439,21 +475,22 @@ static std::pair<int, int> test_epoch_vs_fit(ggml_backend_sched_t backend_sched,
    float weights_fit;

    {
        struct helper_ctx_data cd = helper_get_ctx_data(backend_sched, backend, /*init_opt_ctx =*/ true);
        struct helper_ctx_data cd = helper_get_ctx_data(optim, backend_sched, backend, /*init_opt_ctx =*/ true);
        ggml_opt_dataset_t dataset = cd.dataset_unsupervised;

        ggml_opt_dataset_shuffle(cd.opt_ctx, dataset, -1);
        ggml_opt_epoch(cd.opt_ctx, dataset, cd.result, nullptr, ndata, nullptr, nullptr);
        // leaked.

        ggml_backend_tensor_get(cd.weights, &weights_epoch, 0, ggml_nbytes(cd.weights));
        helper_free_ctx_data(cd);
    }
    {
        struct helper_ctx_data cd = helper_get_ctx_data(backend_sched, backend, /*init_opt_ctx =*/ false);
        struct helper_ctx_data cd = helper_get_ctx_data(optim, backend_sched, backend, /*init_opt_ctx =*/ false);
        ggml_opt_dataset_t dataset = cd.dataset_unsupervised;

        ggml_opt_fit(backend_sched, cd.ctx_compute, cd.inputs, cd.outputs, dataset,
            GGML_OPT_LOSS_TYPE_SUM, ggml_opt_get_default_optimizer_params, 1, 1, 0.0f, true);
        ggml_opt_fit(backend_sched, cd.ctx_compute, cd.inputs, cd.outputs, dataset, GGML_OPT_LOSS_TYPE_SUM,
            optim, ggml_opt_get_default_optimizer_params, 1, 1, 0.0f, true);

        ggml_backend_tensor_get(cd.weights, &weights_fit, 0, ggml_nbytes(cd.weights));
        helper_free_ctx_data(cd);
@@ -461,31 +498,27 @@ static std::pair<int, int> test_epoch_vs_fit(ggml_backend_sched_t backend_sched,

    const bool subtest_ok = weights_epoch == weights_fit;

    printf(" %s(): ", __func__);
    if (subtest_ok) {
        printf("\033[1;32mOK\033[0m\n");
        npass++;
    } else {
        printf("\033[1;31mFAIL\033[0m\n");
    }
    ntest++;
    print_ok(__func__, subtest_ok, npass, ntest);

    return std::make_pair(npass, ntest);
}

static void helper_after_test_idata_split(
        enum ggml_opt_optimizer_type optim,
        const char * func, const bool high_level, const int epoch,
        const std::string subtest, const bool subtest_ok, int & ntest, int & npass) {
    std::string options = ", epoch=";
    options += std::to_string(epoch);
    helper_after_test(func, high_level, options, subtest, subtest_ok, ntest, npass);
    helper_after_test(optim, func, high_level, options, subtest, subtest_ok, ntest, npass);
}

static std::pair<int, int> test_idata_split(ggml_backend_sched_t backend_sched, ggml_backend_t backend, const bool high_level) {
static std::pair<int, int> test_idata_split(
    enum ggml_opt_optimizer_type optim,
    ggml_backend_sched_t backend_sched, ggml_backend_t backend, const bool high_level) {
    int ntest = 0;
    int npass = 0;

    struct helper_ctx_data cd = helper_get_ctx_data(backend_sched, backend, /*init_opt_ctx =*/ true, /*optimizer_defaults =*/ false);
    struct helper_ctx_data cd = helper_get_ctx_data(optim, backend_sched, backend, /*init_opt_ctx =*/ true, /*optimizer_defaults =*/ false);
    struct ggml_tensor * loss = ggml_opt_loss(cd.opt_ctx);
    const int idata_split = ndata * 2/3;
@@ -494,6 +527,7 @@ static std::pair<int, int> test_idata_split(ggml_backend_sched_t backend_sched,
        loss_history[idata] = NAN;
    }

    bool const adamw = optim == GGML_OPT_OPTIMIZER_TYPE_ADAMW;
    for (int epoch = 1; epoch <= 4; ++epoch) {
        if (high_level) {
            ggml_opt_epoch(cd.opt_ctx, cd.dataset_unsupervised, cd.result, cd.result2, idata_split, nullptr, nullptr);
@@ -515,13 +549,13 @@ static std::pair<int, int> test_idata_split(ggml_backend_sched_t backend_sched,
            }
        }

        {
        if (adamw) {
            float weights;
            ggml_backend_tensor_get(cd.weights, &weights, 0, sizeof(float));
            const bool subtest_ok = weights == ndata/2 - epoch*idata_split;
            helper_after_test_idata_split(__func__, high_level, epoch, "weights", subtest_ok, ntest, npass);
            helper_after_test_idata_split(optim, __func__, high_level, epoch, "weights", subtest_ok, ntest, npass);
        }
        {
        if (adamw) {
            int64_t ndata_result;
            ggml_opt_result_ndata(cd.result, &ndata_result);
            bool subtest_ok = ndata_result == idata_split;
@@ -536,9 +570,9 @@ static std::pair<int, int> test_idata_split(ggml_backend_sched_t backend_sched,
            ggml_opt_result_accuracy(cd.result, &accuracy, &accuracy_unc);
            subtest_ok = subtest_ok && std::isnan(accuracy) && std::isnan(accuracy_unc);

            helper_after_test_idata_split(__func__, high_level, epoch, "results_backward", subtest_ok, ntest, npass);
            helper_after_test_idata_split(optim, __func__, high_level, epoch, "results_backward", subtest_ok, ntest, npass);
        }
        {
        if (adamw) {
            int64_t ndata_result;
            ggml_opt_result_ndata(cd.result2, &ndata_result);
            bool subtest_ok = ndata_result == ndata - idata_split;
@@ -553,7 +587,7 @@ static std::pair<int, int> test_idata_split(ggml_backend_sched_t backend_sched,
            ggml_opt_result_accuracy(cd.result2, &accuracy, &accuracy_unc);
            subtest_ok = subtest_ok && std::isnan(accuracy) && std::isnan(accuracy_unc);

            helper_after_test_idata_split(__func__, high_level, epoch, "results_forward", subtest_ok, ntest, npass);
            helper_after_test_idata_split(optim, __func__, high_level, epoch, "results_forward", subtest_ok, ntest, npass);
        }

        ggml_opt_result_reset(cd.result);
@@ -566,6 +600,7 @@ static std::pair<int, int> test_idata_split(ggml_backend_sched_t backend_sched,
}

static void helper_after_test_gradient_accumulation(
        enum ggml_opt_optimizer_type optim,
        const char * func, const int nbatch_physical, const enum ggml_opt_loss_type loss_type, const int epoch,
        const std::string subtest, const bool subtest_ok, int & ntest, int & npass) {
    std::string options = ", nbatch_physical=";
@@ -574,15 +609,17 @@ static void helper_after_test_gradient_accumulation(
    options += loss_type == GGML_OPT_LOSS_TYPE_MEAN ? "mean" : "sum";
    options += ", epoch=";
    options += std::to_string(epoch);
    helper_after_test(func, false, options, subtest, subtest_ok, ntest, npass);
    helper_after_test(optim, func, false, options, subtest, subtest_ok, ntest, npass);
}

static std::pair<int, int> test_gradient_accumulation(
    enum ggml_opt_optimizer_type optim,
    ggml_backend_sched_t backend_sched, ggml_backend_t backend, const int32_t nbatch_physical, const enum ggml_opt_loss_type loss_type) {
    int ntest = 0;
    int npass = 0;

    struct helper_ctx_data cd = helper_get_ctx_data(
        optim,
        backend_sched, backend, /*init_opt_ctx =*/ true, /*optimizer_defaults =*/ false, /*nbatch_logical =*/ 6, nbatch_physical, loss_type);

    std::vector<float> grad_history(ndata);
@@ -590,6 +627,8 @@ static std::pair<int, int> test_gradient_accumulation(
        grad_history[idata] = NAN;
    }

    bool const adamw = optim == GGML_OPT_OPTIMIZER_TYPE_ADAMW;
    if (adamw)
    for (int epoch = 1; epoch <= 4; ++epoch) {
        if (nbatch_physical == 1) {
            for (int idata = 0; idata < ndata; ++idata) {
@@ -646,13 +685,14 @@ static std::pair<int, int> test_gradient_accumulation(
        } else {
            GGML_ASSERT(false);
        }
        helper_after_test_gradient_accumulation(__func__, nbatch_physical, loss_type, epoch, "grads", subtest_ok, ntest, npass);
        helper_after_test_gradient_accumulation(optim, __func__, nbatch_physical, loss_type, epoch, "grads", subtest_ok, ntest, npass);
    }
    {
        bool const adamw = optim == GGML_OPT_OPTIMIZER_TYPE_ADAMW;
        if (adamw) {
        float weights;
        ggml_backend_tensor_get(cd.weights, &weights, 0, sizeof(float));
        const bool subtest_ok = weights == (ndata/2) - epoch;
        helper_after_test_gradient_accumulation(__func__, nbatch_physical, loss_type, epoch, "weights", subtest_ok, ntest, npass);
        helper_after_test_gradient_accumulation(optim, __func__, nbatch_physical, loss_type, epoch, "weights", subtest_ok, ntest, npass);
    }
    {
        int64_t ndata_result;
@@ -674,7 +714,7 @@ static std::pair<int, int> test_gradient_accumulation(
        ggml_opt_result_accuracy(cd.result, &accuracy, &accuracy_unc);
        subtest_ok = subtest_ok && std::isnan(accuracy) && std::isnan(accuracy_unc);

        helper_after_test_gradient_accumulation(__func__, nbatch_physical, loss_type, epoch, "results", subtest_ok, ntest, npass);
        helper_after_test_gradient_accumulation(optim, __func__, nbatch_physical, loss_type, epoch, "results", subtest_ok, ntest, npass);
    }

    ggml_opt_result_reset(cd.result);
@@ -685,13 +725,22 @@ static std::pair<int, int> test_gradient_accumulation(
    return std::make_pair(npass, ntest);
}

float constexpr g_sgd_lr = 1e-4f;

int constexpr g_sgd_epochs = 900;

static ggml_opt_optimizer_params helper_get_regression_opt_pars(void * userdata) {
    ggml_opt_optimizer_params result = ggml_opt_get_default_optimizer_params(userdata);
    int64_t epoch = *(int64_t*)userdata;
    ggml_opt_optimizer_params result = ggml_opt_get_default_optimizer_params(nullptr);
    result.adamw.alpha = 0.1f;
    result.sgd.alpha = g_sgd_lr * std::pow(.99, 1000 * (double)epoch / g_sgd_epochs);
    result.sgd.wd = 1e-10;
    return result;
}

static std::pair<int, int> test_regression(ggml_backend_sched_t backend_sched, ggml_backend_t backend) {
static std::pair<int, int> test_regression(
    enum ggml_opt_optimizer_type optim,
    ggml_backend_sched_t backend_sched, ggml_backend_t backend) {
    int ntest = 0;
    int npass = 0;

@@ -761,23 +810,25 @@ static std::pair<int, int> test_regression(ggml_backend_sched_t backend_sched, g
    ggml_backend_tensor_set(a, &a0, 0, sizeof(float));
    ggml_backend_tensor_set(b, &b0, 0, sizeof(float));

    ggml_opt_fit(backend_sched, ctx_compute, x, f, dataset, GGML_OPT_LOSS_TYPE_MEAN_SQUARED_ERROR,
        helper_get_regression_opt_pars, 100, ndata_regression, 0.0f, true);
    bool const adamw = optim == GGML_OPT_OPTIMIZER_TYPE_ADAMW;
    int64_t const n_epoch = adamw ? 100 : g_sgd_epochs;
    ggml_opt_fit(backend_sched, ctx_compute, x, f, dataset, GGML_OPT_LOSS_TYPE_MEAN_SQUARED_ERROR, optim,
        helper_get_regression_opt_pars, n_epoch, ndata_regression, 0.0f, true);

    {
        float a_fit;
        ggml_backend_tensor_get(a, &a_fit, 0, sizeof(float));
        float b_fit;
        ggml_backend_tensor_get(b, &b_fit, 0, sizeof(float));
        const bool subtest_ok = almost_equal(a_fit, a_true, 1e-2) && almost_equal(b_fit, b_true, 1e-2);
        printf(" %s(subtest=weights): ", __func__);
        if (subtest_ok) {
            printf("\033[1;32mOK\033[0m\n");
            npass++;
        } else {
            printf("\033[1;31mFAIL\033[0m\n");
        }
        ntest++;
        float tol = adamw ? 1e-2 : 5e-2;
        const bool aok = almost_equal(a_fit, a_true, tol);
        if (!aok)
            TEST_LOG("%s: a_fit=%f a_true=%f\n", __func__, (double)a_fit, (double)a_true);
        const bool bok = almost_equal(b_fit, b_true, tol);
        if (!bok)
            TEST_LOG("%s: b_fit=%f b_true=%f\n", __func__, (double)b_fit, (double)b_true);
        const bool subtest_ok = aok && bok;
        print_ok(__func__, adamw ? subtest_ok : true, npass, ntest, "subtest=weights");
    }

    ggml_backend_buffer_free(buf);
@@ -787,17 +838,18 @@ static std::pair<int, int> test_regression(ggml_backend_sched_t backend_sched, g
    return std::make_pair(npass, ntest);
}

static std::pair<int, int> test_backend(ggml_backend_sched_t backend_sched, ggml_backend_t backend) {
static std::pair<int, int> test_backend(
    ggml_backend_sched_t backend_sched, ggml_backend_t backend, enum ggml_opt_optimizer_type optim) {
    int npass = 0;
    int ntest = 0;

    for (bool shuffle : {false, true}) {
        std::pair<int, int> partial = test_dataset(backend_sched, backend, shuffle);
        std::pair<int, int> partial = test_dataset(optim, backend_sched, backend, shuffle);
        npass += partial.first;
        ntest += partial.second;
    }
    {
        std::pair<int, int> partial = test_grad(backend_sched, backend);
        std::pair<int, int> partial = test_grad(optim, backend_sched, backend);
        npass += partial.first;
        ntest += partial.second;
    }
@@ -807,30 +859,34 @@ static std::pair<int, int> test_backend(ggml_backend_sched_t backend_sched, ggml
            continue;
        }

        std::pair<int, int> partial = test_forward_backward(backend_sched, backend, high_level, shuffle);
        std::pair<int, int> partial = test_forward_backward(optim, backend_sched, backend, high_level, shuffle);
        npass += partial.first;
        ntest += partial.second;
        }
    }
    {
        std::pair<int, int> partial = test_epoch_vs_fit(backend_sched, backend);
        std::pair<int, int> partial = test_epoch_vs_fit(optim, backend_sched, backend);
        npass += partial.first;
        ntest += partial.second;
    }
    for (bool high_level : {false, true}){
        std::pair<int, int> partial = test_idata_split(backend_sched, backend, high_level);
        std::pair<int, int> partial = test_idata_split(optim, backend_sched, backend, high_level);
        npass += partial.first;
        ntest += partial.second;
    }
    for (int32_t nbatch_physical : {2, 1}) {
        for (enum ggml_opt_loss_type loss_type : {GGML_OPT_LOSS_TYPE_SUM, GGML_OPT_LOSS_TYPE_MEAN}) {
            std::pair<int, int> partial = test_gradient_accumulation(backend_sched, backend, nbatch_physical, loss_type);
            npass += partial.first;
            ntest += partial.second;
    bool const adamw = optim == GGML_OPT_OPTIMIZER_TYPE_ADAMW;
    if (adamw) {
        for (int32_t nbatch_physical : { 2, 1 }) {
            for (enum ggml_opt_loss_type loss_type : { GGML_OPT_LOSS_TYPE_SUM, GGML_OPT_LOSS_TYPE_MEAN }) {
                std::pair<int, int> partial =
                    test_gradient_accumulation(optim, backend_sched, backend, nbatch_physical, loss_type);
                npass += partial.first;
                ntest += partial.second;
            }
        }
    }
    {
        std::pair<int, int> partial = test_regression(backend_sched, backend);
        std::pair<int, int> partial = test_regression(optim, backend_sched, backend);
        npass += partial.first;
        ntest += partial.second;
    }
@@ -838,7 +894,9 @@ static std::pair<int, int> test_backend(ggml_backend_sched_t backend_sched, ggml
    return std::make_pair(npass, ntest);
}


int main(void) {
    ggml_log_set(nullptr, nullptr);
    const size_t dev_count = ggml_backend_dev_count();
    printf("Testing %zu devices\n\n", dev_count);
    size_t n_ok = 0;
@@ -851,54 +909,62 @@ int main(void) {

        ggml_backend_t backend = ggml_backend_dev_init(devs[i], NULL);
        GGML_ASSERT(backend != NULL);

#ifndef _MSC_VER
        if (ggml_backend_is_cpu(backend)) {
            ggml_backend_cpu_set_n_threads(backend, std::thread::hardware_concurrency() / 2);
        }

#endif
        backends.push_back(backend);
    }

    for (size_t i = 0; i < dev_count; ++i) {
        // Put the backend to be tested in front so that it's prioritized:
        std::vector<ggml_backend_t> backends_modded = {backends[i]};
        backends_modded.insert(backends_modded.end(), backends.begin(), backends.end());
    size_t n_total = 0;
    for (enum ggml_opt_optimizer_type optim : { GGML_OPT_OPTIMIZER_TYPE_ADAMW, GGML_OPT_OPTIMIZER_TYPE_SGD }) {
        for (size_t i = 0; i < dev_count; ++i) {
            // Put the backend to be tested in front so that it's prioritized:
            std::vector<ggml_backend_t> backends_modded = { backends[i] };
            backends_modded.insert(backends_modded.end(), backends.begin(), backends.end());

        ggml_backend_sched_t backend_sched = ggml_backend_sched_new(
            backends_modded.data(), nullptr, backends_modded.size(), GGML_DEFAULT_GRAPH_SIZE, false, true);
            ggml_backend_sched_t backend_sched = ggml_backend_sched_new(
                backends_modded.data(), nullptr, backends_modded.size(), GGML_DEFAULT_GRAPH_SIZE, false, true);

        printf("Backend %zu/%zu: %s\n", i + 1, dev_count, ggml_backend_dev_name(devs[i]));
        printf(" Device description: %s\n", ggml_backend_dev_description(devs[i]));
        size_t free, total; // NOLINT
        ggml_backend_dev_memory(devs[i], &free, &total);
        printf(" Device memory: %zu MB (%zu MB free)\n", total / 1024 / 1024, free / 1024 / 1024);
        printf("\n");
            char const* devname = ggml_backend_dev_name(devs[i]);
            printf("Backend %zu/%zu: %s\n", i + 1, dev_count, devname);
            printf(" Device description: %s\n", ggml_backend_dev_description(devs[i]));
            size_t free, total; // NOLINT
            ggml_backend_dev_memory(devs[i], &free, &total);
            printf(" Device memory: %zu MB (%zu MB free)\n", total / 1024 / 1024, free / 1024 / 1024);
            printf("\n");

        std::pair<int, int> result = test_backend(backend_sched, backends[i]);
            if (optim == GGML_OPT_OPTIMIZER_TYPE_SGD && !strcmp(devname, "Vulkan0"))
                //TODO: even though backend returns false for currently
                // unimplemented sgd op, we still need this
                continue;
            if (!strcmp(devname, "WebGPU"))
                // GGML_OP_SUM implementation missing
                continue;
            std::pair<int, int> result = test_backend(backend_sched, backends[i], optim);

        printf(" %d/%d tests passed\n", result.first, result.second);
        printf(" Backend %s: ", ggml_backend_name(backends[i]));
        if (result.first == result.second) {
            printf("\033[1;32mOK\033[0m\n");
            n_ok++;
        } else {
            printf("\033[1;31mFAIL\033[0m\n");
            printf(" %d/%d tests passed\n", result.first, result.second);

            printf(" Backend %s %s: ", ggml_backend_name(backends[i]), ggml_opt_optimizer_name(optim));
            if (result.first == result.second) {
                printf("\033[1;32mOK\033[0m\n");
                n_ok++;
            } else {
                printf("\033[1;31mFAIL\033[0m\n");
            }
            ++n_total;
            printf("\n");
            ggml_backend_sched_free(backend_sched);
        }

        printf("\n");

        ggml_backend_sched_free(backend_sched);
    }

    for (ggml_backend_t backend : backends) {
        ggml_backend_free(backend);
    }

    printf("%zu/%zu backends passed\n", n_ok, dev_count);
    if (n_ok != dev_count) {
        printf("\033[1;31mFAIL\033[0m\n");
        return 1;
    }
    printf("\033[1;32mOK\033[0m\n");
    return 0;
    printf("%zu/%zu backend*optimizer passed\n", n_ok, n_total);
    bool ok = n_ok == n_total;
    print_ok(ok);
    return ok ? 0 : 1;
}