From 3d443bd0a7e1d9eebfa37321fc8118d8d538af13 Mon Sep 17 00:00:00 2001
From: Roeeeee <35409124+5c077m4n@users.noreply.github.com>
Date: Thu, 16 May 2024 20:03:16 +0300
Subject: [PATCH] feat(persistence): add `pre-` and `post-` load hooks (#24)

* Add optional hook typings to config object

* Add optional hooks conditional calls

* Add hooks description to main readme
---
 README.md                  |  3 +++
 lua/persistence/config.lua |  3 +++
 lua/persistence/init.lua   | 12 ++++++++++++
 3 files changed, 18 insertions(+)

diff --git a/README.md b/README.md
index 1b7b8de..fd355cc 100644
--- a/README.md
+++ b/README.md
@@ -37,7 +37,10 @@ Persistence comes with the following defaults:
   dir = vim.fn.expand(vim.fn.stdpath("state") .. "/sessions/"), -- directory where session files are saved
   options = { "buffers", "curdir", "tabpages", "winsize" }, -- sessionoptions used for saving
   pre_save = nil, -- a function to call before saving the session
+  post_save = nil, -- a function to call after saving the session
   save_empty = false, -- don't save if there are no open file buffers
+  pre_load = nil, -- a function to call before loading the session
+  post_load = nil, -- a function to call after loading the session
 }
 ```
 
diff --git a/lua/persistence/config.lua b/lua/persistence/config.lua
index 3bc4b1b..1612f21 100644
--- a/lua/persistence/config.lua
+++ b/lua/persistence/config.lua
@@ -2,6 +2,9 @@ local M = {}
 
 ---@class PersistenceOptions
 ---@field pre_save? fun()
+---@field post_save? fun()
+---@field pre_load? fun()
+---@field post_load? fun()
 local defaults = {
   dir = vim.fn.expand(vim.fn.stdpath("state") .. "/sessions/"), -- directory where session files are saved
   options = { "buffers", "curdir", "tabpages", "winsize", "skiprtp" }, -- sessionoptions used for saving
diff --git a/lua/persistence/init.lua b/lua/persistence/init.lua
index 59c3d40..d6ac1f5 100644
--- a/lua/persistence/init.lua
+++ b/lua/persistence/init.lua
@@ -56,6 +56,10 @@ function M.start()
       end
 
       M.save()
+
+      if type(Config.options.post_save) == "function" then
+        Config.options.post_save()
+      end
     end,
   })
 end
@@ -76,7 +80,15 @@ function M.load(opt)
   opt = opt or {}
   local sfile = opt.last and M.get_last() or M.get_current()
   if sfile and vim.fn.filereadable(sfile) ~= 0 then
+    if type(Config.options.pre_load) == "function" then
+      Config.options.pre_load()
+    end
+
     vim.cmd("silent! source " .. e(sfile))
+
+    if type(Config.options.post_load) == "function" then
+      Config.options.post_load()
+    end
   end
 end