diff --git a/common/minja/chat-template.hpp b/common/minja/chat-template.hpp index 237a7625c..c930a587a 100644 --- a/common/minja/chat-template.hpp +++ b/common/minja/chat-template.hpp @@ -13,10 +13,12 @@ #include #include #include +#include #include #include #include #include +#include #include #include @@ -393,8 +395,8 @@ class chat_template { for (const auto & message_ : adjusted_messages) { auto message = message_; - if (!message.contains("role") || !message.contains("content")) { - throw std::runtime_error("message must have 'role' and 'content' fields: " + message.dump()); + if (!message.contains("role") || (!message.contains("content") && !message.contains("tool_calls"))) { + throw std::runtime_error("message must have 'role' and one of 'content' or 'tool_calls' fields: " + message.dump()); } std::string role = message.at("role"); @@ -415,7 +417,6 @@ class chat_template { } } if (polyfill_tool_calls) { - auto content = message.at("content"); auto tool_calls = json::array(); for (const auto & tool_call : message.at("tool_calls")) { if (tool_call.at("type") != "function") { @@ -434,8 +435,11 @@ class chat_template { auto obj = json { {"tool_calls", tool_calls}, }; - if (!content.is_null() && !content.empty()) { - obj["content"] = content; + if (message.contains("content")) { + auto content = message.at("content"); + if (!content.is_null() && !content.empty()) { + obj["content"] = content; + } } message["content"] = obj.dump(2); message.erase("tool_calls"); diff --git a/common/minja/minja.hpp b/common/minja/minja.hpp index e52e792d8..b3b00547d 100644 --- a/common/minja/minja.hpp +++ b/common/minja/minja.hpp @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -233,7 +234,7 @@ public: } } else if (is_object()) { if (!index.is_hashable()) - throw std::runtime_error("Unashable type: " + index.dump()); + throw std::runtime_error("Unhashable type: " + index.dump()); auto it = object_->find(index.primitive_); if (it == object_->end()) throw std::runtime_error("Key not found: " + index.dump()); @@ -252,7 +253,7 @@ public: auto index = key.get(); return array_->at(index < 0 ? array_->size() + index : index); } else if (object_) { - if (!key.is_hashable()) throw std::runtime_error("Unashable type: " + dump()); + if (!key.is_hashable()) throw std::runtime_error("Unhashable type: " + dump()); auto it = object_->find(key.primitive_); if (it == object_->end()) return Value(); return it->second; @@ -261,7 +262,7 @@ public: } void set(const Value& key, const Value& value) { if (!object_) throw std::runtime_error("Value is not an object: " + dump()); - if (!key.is_hashable()) throw std::runtime_error("Unashable type: " + dump()); + if (!key.is_hashable()) throw std::runtime_error("Unhashable type: " + dump()); (*object_)[key.primitive_] = value; } Value call(const std::shared_ptr & context, ArgumentsValue & args) const { @@ -398,7 +399,7 @@ public: } return false; } else if (object_) { - if (!value.is_hashable()) throw std::runtime_error("Unashable type: " + value.dump()); + if (!value.is_hashable()) throw std::runtime_error("Unhashable type: " + value.dump()); return object_->find(value.primitive_) != object_->end(); } else { throw std::runtime_error("contains can only be called on arrays and objects: " + dump()); @@ -416,7 +417,7 @@ public: return const_cast(this)->at(index); } Value& at(const Value & index) { - if (!index.is_hashable()) throw std::runtime_error("Unashable type: " + dump()); + if (!index.is_hashable()) throw std::runtime_error("Unhashable type: " + dump()); if (is_array()) return array_->at(index.get()); if (is_object()) return object_->at(index.primitive_); throw std::runtime_error("Value is not an array or object: " + dump()); @@ -676,8 +677,8 @@ public: class VariableExpr : public Expression { std::string name; public: - VariableExpr(const Location & location, const std::string& n) - : Expression(location), name(n) {} + VariableExpr(const Location & loc, const std::string& n) + : Expression(loc), name(n) {} std::string get_name() const { return name; } Value do_evaluate(const std::shared_ptr & context) const override { if (!context->contains(name)) { @@ -1200,9 +1201,9 @@ public: class SliceExpr : public Expression { public: - std::shared_ptr start, end; - SliceExpr(const Location & loc, std::shared_ptr && s, std::shared_ptr && e) - : Expression(loc), start(std::move(s)), end(std::move(e)) {} + std::shared_ptr start, end, step; + SliceExpr(const Location & loc, std::shared_ptr && s, std::shared_ptr && e, std::shared_ptr && st = nullptr) + : Expression(loc), start(std::move(s)), end(std::move(e)), step(std::move(st)) {} Value do_evaluate(const std::shared_ptr &) const override { throw std::runtime_error("SliceExpr not implemented"); } @@ -1219,18 +1220,35 @@ public: if (!index) throw std::runtime_error("SubscriptExpr.index is null"); auto target_value = base->evaluate(context); if (auto slice = dynamic_cast(index.get())) { - auto start = slice->start ? slice->start->evaluate(context).get() : 0; - auto end = slice->end ? slice->end->evaluate(context).get() : (int64_t) target_value.size(); + auto len = target_value.size(); + auto wrap = [len](int64_t i) -> int64_t { + if (i < 0) { + return i + len; + } + return i; + }; + int64_t step = slice->step ? slice->step->evaluate(context).get() : 1; + if (!step) { + throw std::runtime_error("slice step cannot be zero"); + } + int64_t start = slice->start ? wrap(slice->start->evaluate(context).get()) : (step < 0 ? len - 1 : 0); + int64_t end = slice->end ? wrap(slice->end->evaluate(context).get()) : (step < 0 ? -1 : len); if (target_value.is_string()) { std::string s = target_value.get(); - if (start < 0) start = s.size() + start; - if (end < 0) end = s.size() + end; - return s.substr(start, end - start); + + std::string result; + if (start < end && step == 1) { + result = s.substr(start, end - start); + } else { + for (int64_t i = start; step > 0 ? i < end : i > end; i += step) { + result += s[i]; + } + } + return result; + } else if (target_value.is_array()) { - if (start < 0) start = target_value.size() + start; - if (end < 0) end = target_value.size() + end; auto result = Value::array(); - for (auto i = start; i < end; ++i) { + for (int64_t i = start; step > 0 ? i < end : i > end; i += step) { result.push_back(target_value.at(i)); } return result; @@ -1305,6 +1323,8 @@ public: if (name == "iterable") return l.is_iterable(); if (name == "sequence") return l.is_array(); if (name == "defined") return !l.is_null(); + if (name == "true") return l.to_bool(); + if (name == "false") return !l.to_bool(); throw std::runtime_error("Unknown type for 'is' operator: " + name); }; auto value = eval(); @@ -1520,6 +1540,10 @@ public: vargs.expectArgs("endswith method", {1, 1}, {0, 0}); auto suffix = vargs.args[0].get(); return suffix.length() <= str.length() && std::equal(suffix.rbegin(), suffix.rend(), str.rbegin()); + } else if (method->get_name() == "startswith") { + vargs.expectArgs("startswith method", {1, 1}, {0, 0}); + auto prefix = vargs.args[0].get(); + return prefix.length() <= str.length() && std::equal(prefix.begin(), prefix.end(), str.begin()); } else if (method->get_name() == "title") { vargs.expectArgs("title method", {0, 0}, {0, 0}); auto res = str; @@ -2082,28 +2106,37 @@ private: while (it != end && consumeSpaces() && peekSymbols({ "[", "." })) { if (!consumeToken("[").empty()) { - std::shared_ptr index; + std::shared_ptr index; + auto slice_loc = get_location(); + std::shared_ptr start, end, step; + bool has_first_colon = false, has_second_colon = false; + + if (!peekSymbols({ ":" })) { + start = parseExpression(); + } + + if (!consumeToken(":").empty()) { + has_first_colon = true; + if (!peekSymbols({ ":", "]" })) { + end = parseExpression(); + } if (!consumeToken(":").empty()) { - auto slice_end = parseExpression(); - index = std::make_shared(slice_end->location, nullptr, std::move(slice_end)); - } else { - auto slice_start = parseExpression(); - if (!consumeToken(":").empty()) { - consumeSpaces(); - if (peekSymbols({ "]" })) { - index = std::make_shared(slice_start->location, std::move(slice_start), nullptr); - } else { - auto slice_end = parseExpression(); - index = std::make_shared(slice_start->location, std::move(slice_start), std::move(slice_end)); - } - } else { - index = std::move(slice_start); + has_second_colon = true; + if (!peekSymbols({ "]" })) { + step = parseExpression(); } } - if (!index) throw std::runtime_error("Empty index in subscript"); - if (consumeToken("]").empty()) throw std::runtime_error("Expected closing bracket in subscript"); + } - value = std::make_shared(value->location, std::move(value), std::move(index)); + if ((has_first_colon || has_second_colon) && (start || end || step)) { + index = std::make_shared(slice_loc, std::move(start), std::move(end), std::move(step)); + } else { + index = std::move(start); + } + if (!index) throw std::runtime_error("Empty index in subscript"); + if (consumeToken("]").empty()) throw std::runtime_error("Expected closing bracket in subscript"); + + value = std::make_shared(value->location, std::move(value), std::move(index)); } else if (!consumeToken(".").empty()) { auto identifier = parseIdentifier(); if (!identifier) throw std::runtime_error("Expected identifier in subscript");