Mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-09-02 13:24:28 -04:00)
finetune: SGD optimizer, more CLI args (#13873)
* examples/finetune -opt SGD (stochastic gradient descent) memory opt

  Add unit-tested GGML_OPT_OPTIMIZER_SGD to ggml - avoids allocating the m, v
  momentum tensors. Support the finetune.cpp arg -opt SGD (or sgd); the default
  is adamw as before.

  llama 3.2-1b-F32 result: observed 11 GB GPU RAM (41 sec/epoch) when using SGD
  instead of 19 GB (55 sec/epoch) using adamw (wikipedia 100-line finetune).

  Using the same GPU memory, adamw can only fit 512 batch/context before OOM,
  reaching:

  train: [███████▉] data=0000140/0000140 loss=0.02575±0.00099 acc=99.52±0.03% t=00:00:47 ETA=00:00:00
  val:   [███████▉] data=0000008/0000008 loss=4.76565±0.28810 acc=41.46±0.77% t=00:00:00 ETA=00:00:00

  SGD is superior here: though it converges more slowly, it fits a max of 1728
  batch/context before OOM (note especially the better validation perf):

  train: [███████▉] data=0000039/0000039 loss=0.00371±0.00010 acc=99.96±0.01% t=00:00:41 ETA=00:00:00
  val:   [███████▉] data=0000003/0000003 loss=5.11406±0.76034 acc=48.01±0.69% t=00:00:01 ETA=00:00:00

  Note: when finetuning long enough (or with a high enough -lr), validation
  accuracy *eventually* drops ('catastrophic forgetting').

  The -lr-half (halflife) option is useful for SGD to avoid oscillation or very
  slow underdamped learning (it makes setting -lr more forgiving). The terminal
  -lr is for now set by -lr-halvings, i.e. if you want at most 1/8 the initial
  -lr you set -lr-halvings 3.

  Note: objective loss is not directly comparable between adamw and sgd; check
  perplexity or accuracy, or consider relative improvements, when judging
  convergence.

  New finetune args: -wd 1e-9 to enable weight decay in sgd or adamw, and a max
  -epochs N (default 2 as before).

  Caching (1 - wd*alpha) in the 'adamw' opt struct showed no noticeable perf
  benefit and is disabled (it is still done for the new SGD, though).

  Since optimizer memory is pre-allocated, the ggml_opt_get_optimizer_params
  callback could probably switch between SGD and AdamW with each epoch, but
  would need to use adamw for the first (unconfirmed - there is no cmdline arg
  to set such a policy yet).

  test-opt checks adamw as before and now also sgd (except for a few tests
  disabled for sgd only; these probably just need logging of values and
  alternate reference values). The tolerance on the 'regression' test is
  broader for sgd, so we don't need many more epochs.

* Vulkan: Implement GGML_OP_OPT_STEP_SGD

* tests: Fix OPT_STEP_SGD test-backend-ops

* SGD op param store weight-decay and not 1-alpha*wd

* minor + cosmetic changes

* fix vulkan sgd

* try CI fix

---------

Co-authored-by: 0cc4m <picard12@live.de>
Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
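For orientation, the update rule and learning-rate floor described above can be sketched as follows. This is a minimal illustration, not the finetune.cpp or ggml implementation: the function names are made up, and it assumes SGD applies decoupled weight decay by scaling the parameter with (1 - alpha*wd) before the gradient step, and that each -lr-halvings step halves the initial -lr.

    #include <math.h>
    #include <stddef.h>

    // Sketch of one SGD step with decoupled weight decay: scale the parameter
    // by (1 - alpha*wd), then take a plain gradient step. wd = 0.0f disables
    // weight decay. Illustrative only, not the GGML_OP_OPT_STEP_SGD kernel.
    static void sgd_step_sketch(float * x, const float * grad, size_t n, float alpha, float wd) {
        const float keep = 1.0f - alpha * wd;
        for (size_t i = 0; i < n; ++i) {
            x[i] = keep * x[i] - alpha * grad[i];
        }
    }

    // Terminal learning rate implied by -lr-halvings: each halving multiplies
    // the initial -lr by 0.5, so -lr-halvings 3 bounds it at 1/8 of -lr.
    static float terminal_lr_sketch(float lr0, int lr_halvings) {
        return lr0 * powf(0.5f, (float) lr_halvings);
    }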
@@ -74,16 +74,26 @@ extern "C" {
         GGML_OPT_BUILD_TYPE_OPT     = 30,
     };
 
+    enum ggml_opt_optimizer_type {
+        GGML_OPT_OPTIMIZER_TYPE_ADAMW,
+        GGML_OPT_OPTIMIZER_TYPE_SGD,
+
+        GGML_OPT_OPTIMIZER_TYPE_COUNT
+    };
+
     // parameters that control which optimizer is used and how said optimizer tries to find the minimal loss
     struct ggml_opt_optimizer_params {
-        // AdamW optimizer parameters
         struct {
             float alpha; // learning rate
-            float beta1;
-            float beta2;
+            float beta1; // first AdamW momentum
+            float beta2; // second AdamW momentum
             float eps;   // epsilon for numerical stability
-            float wd;    // weight decay for AdamW, use 0.0f to disable
+            float wd;    // weight decay - 0.0f to disable
         } adamw;
+        struct {
+            float alpha; // learning rate
+            float wd;    // weight decay
+        } sgd;
     };
 
     // callback to calculate optimizer parameters prior to a backward pass
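To show how the new sgd member of ggml_opt_optimizer_params might be filled in, here is a sketch of a user callback. It assumes the ggml_opt_get_optimizer_params callback referenced in the comments above takes a void * userdata and returns the struct by value; the constants are arbitrary example values, and the adamw member is left zeroed on the assumption that it is not read when SGD is selected.

    #include <stddef.h>
    #include "ggml-opt.h"

    // Sketch of a user-supplied optimizer-parameter callback that returns
    // constant SGD hyperparameters; userdata (a pointer to the epoch when
    // used with ggml_opt_fit) is ignored here.
    static struct ggml_opt_optimizer_params my_sgd_pars(void * userdata) {
        (void) userdata;
        struct ggml_opt_optimizer_params pars = {0};
        pars.sgd.alpha = 1e-4f; // learning rate
        pars.sgd.wd    = 1e-9f; // weight decay, 0.0f to disable
        return pars;
    }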
@@ -112,8 +122,11 @@ extern "C" {
 
         int32_t opt_period; // after how many gradient accumulation steps an optimizer step should be done
 
-        ggml_opt_get_optimizer_params get_opt_pars; // callback for calculating optimizer parameters
-        void * get_opt_pars_ud;                     // userdata for calculating optimizer parameters
+        ggml_opt_get_optimizer_params get_opt_pars; // callback for calculating optimizer parameters
+        void * get_opt_pars_ud;                     // userdata for calculating optimizer parameters
+
+        // only GGML_OPT_OPTIMIZER_TYPE_ADAMW needs m, v momenta per parameter tensor
+        enum ggml_opt_optimizer_type optimizer;
     };
 
     // get parameters for an optimization context with defaults set where possible
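A sketch of selecting SGD through the new optimizer field of ggml_opt_params. It assumes a two-argument ggml_opt_default_params (backend scheduler and loss type), which is not shown in this hunk, and reuses the my_sgd_pars callback sketched above; treat it as illustrative rather than as the exact API.

    #include <stddef.h>
    #include "ggml-opt.h"

    // Sketch: build optimization-context parameters that use SGD. Only the
    // AdamW optimizer would allocate m, v momenta per parameter tensor.
    static struct ggml_opt_params make_sgd_opt_params(ggml_backend_sched_t sched) {
        struct ggml_opt_params params = ggml_opt_default_params(sched, GGML_OPT_LOSS_TYPE_CROSS_ENTROPY);
        params.optimizer       = GGML_OPT_OPTIMIZER_TYPE_SGD;
        params.get_opt_pars    = my_sgd_pars; // callback for calculating optimizer parameters
        params.get_opt_pars_ud = NULL;        // userdata for that callback
        return params;
    }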
@@ -142,6 +155,10 @@ extern "C" {
     // get the gradient accumulator for a node from the forward graph
     GGML_API struct ggml_tensor * ggml_opt_grad_acc(ggml_opt_context_t opt_ctx, struct ggml_tensor * node);
 
+    GGML_API enum ggml_opt_optimizer_type ggml_opt_context_optimizer_type(ggml_opt_context_t); //TODO consistent naming scheme
+
+    GGML_API const char * ggml_opt_optimizer_name(enum ggml_opt_optimizer_type);
+
     // ====== Optimization Result ======
 
     GGML_API ggml_opt_result_t ggml_opt_result_init(void);
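The two new accessors pair naturally for logging which optimizer an existing context uses; a minimal sketch:

    #include <stdio.h>
    #include "ggml-opt.h"

    // Sketch: print the name of the optimizer used by an optimization context.
    static void print_optimizer(ggml_opt_context_t opt_ctx) {
        enum ggml_opt_optimizer_type type = ggml_opt_context_optimizer_type(opt_ctx);
        fprintf(stderr, "optimizer: %s\n", ggml_opt_optimizer_name(type));
    }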
@@ -226,12 +243,14 @@ extern "C" {
             struct ggml_tensor            * outputs,        // output tensor, must have shape [ne_label, ndata_batch] if labels are used
             ggml_opt_dataset_t              dataset,        // dataset with data and optionally also labels
             enum ggml_opt_loss_type         loss_type,      // loss to minimize
+            enum ggml_opt_optimizer_type    optimizer,      // sgd or adamw
             ggml_opt_get_optimizer_params   get_opt_pars,   // callback to get optimizer params, userdata is pointer to epoch (of type int64_t)
             int64_t                         nepoch,         // how many times the dataset should be iterated over
             int64_t                         nbatch_logical, // datapoints optimizer step, must be a multiple of ndata_batch in inputs/outputs
             float                           val_split,      // fraction of the dataset to use for validation, must be in [0.0f, 1.0f)
             bool                            silent);        // whether or not info prints to stderr should be suppressed
 
 #ifdef __cplusplus
 }
 #endif
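Finally, a sketch of a ggml_opt_fit call that selects SGD through the new parameter. The three leading arguments (backend scheduler, compute context, inputs tensor) sit above this hunk and are assumed here, the numeric values are placeholders, and my_sgd_pars is the callback sketched earlier.

    #include <stdbool.h>
    #include "ggml-opt.h"

    // Sketch: finetune with SGD through the extended ggml_opt_fit signature.
    // All handles are assumed to have been created beforehand.
    static void finetune_sgd_sketch(
            ggml_backend_sched_t  sched,
            struct ggml_context * ctx_compute,
            struct ggml_tensor  * inputs,
            struct ggml_tensor  * outputs,
            ggml_opt_dataset_t    dataset) {
        ggml_opt_fit(sched, ctx_compute, inputs, outputs, dataset,
            GGML_OPT_LOSS_TYPE_CROSS_ENTROPY,   // loss to minimize
            GGML_OPT_OPTIMIZER_TYPE_SGD,        // sgd instead of the default adamw
            my_sgd_pars,                        // ggml passes a pointer to the epoch as userdata
            /*nepoch         =*/ 2,             // default number of epochs in finetune
            /*nbatch_logical =*/ 512,           // must be a multiple of ndata_batch in inputs/outputs
            /*val_split      =*/ 0.05f,         // fraction of data used for validation, in [0.0f, 1.0f)
            /*silent         =*/ false);        // print progress info to stderr
    }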