9#include <capnp/schema-parser.h>
19#include <kj/filesystem.h>
27#include <system_error>
// NOTE(review): this listing is a lossy extraction. Each line carries a
// fused original line number (e.g. "31#define") and many lines are absent,
// so it does not compile as-is. Restore from the full source before building.
// Names baked into the generator's output:
//  - PROXY_BIN is printed in each generated file's "// Generated by" banner
//    and in the usage message.
//  - PROXY_TYPES is emitted as an #include in the generated .c++ files.
//  - PROXY_DECL is not referenced in the visible lines.
31#define PROXY_BIN "mpgen"
32#define PROXY_DECL "mp/proxy.h"
33#define PROXY_TYPES "mp/proxy-types.h"
// Presumably AnnotationExists(reader, id), per the trailing symbol index.
// It walks reader.getAnnotations() looking for an annotation whose id
// equals `id`.
// NOTE(review): the fused numbers skip 45-46 and stop at 48, so the
// signature and the return statements are not part of this fragment.
44template <
typename Reader>
47 for (
const auto annotation : reader.getAnnotations()) {
48 if (annotation.getId() ==
id) {
// Presumably GetAnnotationText(reader, id, result). For an annotation whose
// id equals `id`, it stores that annotation's text value into *result.
// The kj::StringPtr written is a non-owning view, so it is only valid while
// the schema data it points into stays alive.
// NOTE(review): the signature (orig 56-57) and the returns are not in this
// fragment.
55template <
typename Reader>
58 for (
const auto annotation : reader.getAnnotations()) {
59 if (annotation.getId() ==
id) {
60 *result = annotation.getValue().getText();
// Presumably GetAnnotationInt32(reader, id, result). For an annotation
// whose id equals `id`, it stores that annotation's Int32 value into
// *result.
// NOTE(review): the signature (orig 68-69) and the returns are not in this
// fragment.
67template <
typename Reader>
70 for (
const auto annotation : reader.getAnnotations()) {
71 if (annotation.getId() ==
id) {
72 *result = annotation.getValue().getInt32();
// Iterates interface.getSuperclasses() first, then calls
// callback(interface, method) for each of `interface`'s own methods.
// The callback's first argument tells it which interface declared the
// method.
// NOTE(review): the superclass loop body (orig 82-83) is missing here.
// Presumably it recurses via ForEachMethod(super, callback) so inherited
// methods are reported with their declaring interface. Confirm.
79static void ForEachMethod(
const capnp::InterfaceSchema& interface,
const std::function<
void(
const capnp::InterfaceSchema& interface,
const capnp::InterfaceSchema::Method)>& callback)
81 for (
const auto super : interface.getSuperclasses()) {
84 for (
const auto method : interface.getMethods()) {
85 callback(interface, method);
// Stream helper: writes a character range to `os` with
// os.write(array.begin(), array.size()).
// The `Enable` template parameter is a SFINAE guard. Its default is a
// value-initialized object of begin()'s return type, so this overload is
// only viable when that type can initialize a `const char*` non-type
// template argument (e.g. kj::StringPtr, whose begin() yields const char*).
// The trailing return of `os` (orig 96) is not visible.
92template <class OutputStream, class Array, const char* Enable = decltype(std::declval<Array>().begin())()>
93static OutputStream&
operator<<(OutputStream& os,
const Array& array)
95 os.write(array.begin(), array.size());
// Fragment of a small string-building helper, presumably the `Format` type
// used as `Format() << ...` elsewhere in this file. It has two members:
//  - a templated streaming operator (body not visible);
//  - a non-explicit conversion to std::string that returns m_os.str().
// m_os is presumably a std::ostringstream member. Confirm.
101 template <
typename Value>
107 operator std::string()
const {
return m_os.str(); }
// Copies `str` into a std::string and upper-cases its first character when
// that character is an ASCII lowercase letter ('a'..'z'). All other
// characters are unchanged.
// Used to derive accessor/type names (e.g. Cap("foo") -> "Foo") from
// capnp field and method names.
// NOTE(review): the statement returning `result` (orig 115+) is not in
// this fragment.
111static std::string
Cap(kj::StringPtr str)
113 std::string result = str;
114 if (!result.empty() &&
'a' <= result[0] && result[0] <=
'z') result[0] -=
'a' -
'A';
// Body of BoxedType(type); the signature line (orig 119) is not visible.
// Returns false for capnp scalar types (void, bool, 8/16/32/64-bit signed
// and unsigned integers, float32/float64, enums) and true for everything
// else. Callers append FIELD_BOXED to an accessor's flags when this is
// true.
120 return !(type.isVoid() || type.isBool() || type.isInt8() || type.isInt16() || type.isInt32() || type.isInt64() ||
121 type.isUInt8() || type.isUInt16() || type.isUInt32() || type.isUInt64() || type.isFloat32() ||
122 type.isFloat64() || type.isEnum());
// Generate(): remaining parameters. The line declaring `src_prefix` is not
// visible; the trailing symbol index shows the full signature.
// output_path: src_file with the leading "<src_prefix>/" removed, or
// src_file unchanged when src_prefix is ".". The generated files are
// written at output_path + suffix.
// A src_prefix that is not followed by '/' inside src_file is rejected
// with std::runtime_error.
139 kj::StringPtr include_prefix,
140 kj::StringPtr src_file,
141 const std::vector<kj::StringPtr>& import_paths,
142 const kj::ReadableDirectory& src_dir,
143 const std::vector<kj::Own<const kj::ReadableDirectory>>& import_dirs)
145 std::string output_path;
146 if (src_prefix ==
".") {
147 output_path = src_file;
148 }
else if (!src_file.startsWith(src_prefix) || src_file.size() <= src_prefix.size() ||
149 src_file[src_prefix.size()] !=
'/') {
150 throw std::runtime_error(
"src_prefix is not src_file prefix");
152 output_path = src_file.slice(src_prefix.size() + 1);
// include_path: the same derivation relative to include_prefix. It is the
// path spelled in the generated #include lines.
155 std::string include_path;
156 if (include_prefix ==
".") {
157 include_path = src_file;
158 }
else if (!src_file.startsWith(include_prefix) || src_file.size() <= include_prefix.size() ||
159 src_file[include_prefix.size()] !=
'/') {
160 throw std::runtime_error(
"include_prefix is not src_file prefix");
162 include_path = src_file.slice(include_prefix.size() + 1);
// include_base: include_path with everything from its last '.' removed.
// NOTE(review): rfind('.') searches the whole path, not just the file
// name. An extensionless file under a dotted directory (e.g. "a.b/c")
// would be truncated inside the directory name.
165 std::string include_base = include_path;
166 const std::string::size_type p = include_base.rfind(
'.');
167 if (p != std::string::npos) include_base.erase(p);
// Run the capnp compiler (with the capnpc-c++ output plugin) on src_file.
// The argv built here is:
//   <capnp_PREFIX>/bin/capnp compile --src-prefix=<src_prefix>
//       --import-path=<p> (one per import path)
//       --output=<capnp_PREFIX>/bin/capnpc-c++ <src_file>
// capnp_PREFIX is spliced by string-literal concatenation, so it is
// presumably a build-time string macro.
169 std::vector<std::string>
args;
170 args.emplace_back(capnp_PREFIX
"/bin/capnp");
171 args.emplace_back(
"compile");
172 args.emplace_back(
"--src-prefix=");
173 args.back().append(src_prefix.cStr(), src_prefix.size());
174 for (
const auto& import_path : import_paths) {
175 args.emplace_back(
"--import-path=");
176 args.back().append(import_path.cStr(), import_path.size());
178 args.emplace_back(
"--output=" capnp_PREFIX
"/bin/capnpc-c++");
179 args.emplace_back(src_file);
// NOTE(review): the branches after fork() (orig 181-188) are missing.
// Presumably the child calls ExecProcess(args) and the parent checks
// WaitProcess(pid). Confirm against the full source.
// fork() failure is reported as std::system_error carrying errno.
// A failed capnp run is reported as std::runtime_error.
180 const int pid = fork();
182 throw std::system_error(errno, std::system_category(),
"fork");
189 throw std::runtime_error(
"Invoking " capnp_PREFIX
"/bin/capnp failed");
// Parse the same schema in-process so its nodes can be walked.
// The path is output_path (src_file relative to src_prefix), resolved
// against src_dir; main() opens src_dir from the same argv entry it passes
// as src_prefix.
// directory_pointers is a non-owning array of raw pointers into
// import_dirs, which keeps the kj::Own handles alive. It is built in the
// pointer-array form that parseFromDirectory() takes for import paths.
192 const capnp::SchemaParser parser;
193 auto directory_pointers = kj::heapArray<const kj::ReadableDirectory*>(import_dirs.size());
194 for (
size_t i = 0; i < import_dirs.size(); ++i) {
195 directory_pointers[i] = import_dirs[i].get();
197 auto file_schema = parser.parseFromDirectory(src_dir, kj::Path::parse(output_path), directory_pointers);
// Open the three generated C++ source files alongside output_path. Each
// gets the same prologue, in order:
//  - the "// Generated by" banner;
//  - #include of the generated <include_path>.proxy-types.h;
//  - #include of PROXY_TYPES;
//  - the opening "namespace mp {".
// Contents appended at the end of Generate():
//   .proxy-server.c++  ProxyServer method definitions (def_server)
//   .proxy-client.c++  ProxyClient method definitions (def_client) and
//                      ProxyTypeRegister objects (int_client)
//   .proxy-types.c++   out-of-line ProxyClient/ProxyServer destructors
//                      (def_types)
199 std::ofstream cpp_server(output_path +
".proxy-server.c++");
200 cpp_server <<
"// Generated by " PROXY_BIN " from " << src_file <<
"\n\n";
201 cpp_server <<
"#include <" << include_path <<
".proxy-types.h>\n";
202 cpp_server <<
"#include <" <<
PROXY_TYPES <<
">\n\n";
203 cpp_server <<
"namespace mp {\n";
205 std::ofstream cpp_client(output_path +
".proxy-client.c++");
206 cpp_client <<
"// Generated by " PROXY_BIN " from " << src_file <<
"\n\n";
207 cpp_client <<
"#include <" << include_path <<
".proxy-types.h>\n";
208 cpp_client <<
"#include <" <<
PROXY_TYPES <<
">\n\n";
209 cpp_client <<
"namespace mp {\n";
211 std::ofstream cpp_types(output_path +
".proxy-types.c++");
212 cpp_types <<
"// Generated by " PROXY_BIN " from " << src_file <<
"\n\n";
213 cpp_types <<
"#include <" << include_path <<
".proxy-types.h>\n";
214 cpp_types <<
"#include <" <<
PROXY_TYPES <<
">\n\n";
215 cpp_types <<
"namespace mp {\n";
// Include-guard stem derived from output_path. Digits and uppercase ASCII
// letters are kept; lowercase ASCII letters are upper-cased.
// NOTE(review): the lambda's fallback return for all other characters
// (orig 222) is not visible. Presumably it maps them to '_'. Confirm.
217 std::string guard = output_path;
218 std::transform(guard.begin(), guard.end(), guard.begin(), [](
unsigned char c) ->
unsigned char {
219 if (
'0' <= c && c <=
'9') return c;
220 if (
'A' <= c && c <=
'Z') return c;
221 if (
'a' <= c && c <=
'z') return c -
'a' +
'A';
// <output_path>.proxy-types.h (stream `inl`):
//  - guarded by <guard>_PROXY_TYPES_H;
//  - includes the generated <include_path>.proxy.h;
//  - adds one #include per file-level annotation's text value;
//  - opens namespace mp.
// The line selecting which annotations qualify (orig 231) is not visible;
// presumably it filters on INCLUDE_TYPES_ANNOTATION_ID.
225 std::ofstream inl(output_path +
".proxy-types.h");
226 inl <<
"// Generated by " PROXY_BIN " from " << src_file <<
"\n\n";
227 inl <<
"#ifndef " << guard <<
"_PROXY_TYPES_H\n";
228 inl <<
"#define " << guard <<
"_PROXY_TYPES_H\n\n";
229 inl <<
"#include <" << include_path <<
".proxy.h>\n";
230 for (
const auto annotation : file_schema.getProto().getAnnotations()) {
232 inl <<
"#include <" << annotation.getValue().getText() <<
">\n";
235 inl <<
"namespace mp {\n";
// <output_path>.proxy.h (stream `h`):
//  - guarded by <guard>_PROXY_H;
//  - includes <include_path>.h, presumably the capnpc-c++ output;
//  - adds annotation-selected headers. The filter line (orig 243) is not
//    visible; presumably it uses INCLUDE_ANNOTATION_ID.
237 std::ofstream h(output_path +
".proxy.h");
238 h <<
"// Generated by " PROXY_BIN " from " << src_file <<
"\n\n";
239 h <<
"#ifndef " << guard <<
"_PROXY_H\n";
240 h <<
"#define " << guard <<
"_PROXY_H\n\n";
241 h <<
"#include <" << include_path <<
".h>\n";
242 for (
const auto annotation : file_schema.getProto().getAnnotations()) {
244 h <<
"#include <" << annotation.getValue().getText() <<
">\n";
// For GCC-compatible compilers, the generated header pushes the diagnostic
// state and silences -Wsuggest-override. It does so when __has_warning is
// unavailable, or when __has_warning reports the warning exists.
// The closing "#endif" lines (orig 254-255) are not visible.
// The matching pop is written at the end of Generate().
248 h <<
"#if defined(__GNUC__)\n";
249 h <<
"#pragma GCC diagnostic push\n";
250 h <<
"#if !defined(__has_warning)\n";
251 h <<
"#pragma GCC diagnostic ignored \"-Wsuggest-override\"\n";
252 h <<
"#elif __has_warning(\"-Wsuggest-override\")\n";
253 h <<
"#pragma GCC diagnostic ignored \"-Wsuggest-override\"\n";
256 h <<
"namespace mp {\n";
// message_namespace: the qualifier placed before "::<NodeName>" for capnp
// message types in generated code. Its assignment (orig 259-260) is not
// visible; presumably it is read via NAMESPACE_ANNOTATION_ID.
// base_name: the last path component of include_base. It names the
// generated "<base_name>_fields" namespace that holds the accessor structs.
258 kj::StringPtr message_namespace;
261 std::string base_name = include_base;
262 const size_t output_slash = base_name.rfind(
'/');
263 if (output_slash != std::string::npos) {
264 base_name.erase(0, output_slash + 1);
// Per-section buffers. They are filled while walking the schema and
// flushed into the output files at the end.
// accessors_done de-duplicates accessor structs by field name. It stores
// non-owning kj::StringPtr views into schema data, presumably kept alive
// by `parser`.
267 std::ostringstream methods;
268 std::set<kj::StringPtr> accessors_done;
269 std::ostringstream accessors;
270 std::ostringstream dec;
271 std::ostringstream def_server;
272 std::ostringstream def_client;
273 std::ostringstream int_client;
274 std::ostringstream def_types;
// add_accessor(name): the first time `name` is seen, appends to
// `accessors` a `struct <Cap(name)>` of static helpers. Each helper
// forwards to the matching method on its argument `s`:
//   get<Name>, has<Name>, set<Name>, init<Name>,
//   getWant<Name>, setWant<Name>(true), getHas<Name>, setHas<Name>(true).
// The struct's brace lines (orig 280, 291) are not visible.
276 auto add_accessor = [&](kj::StringPtr
name) {
277 if (!accessors_done.insert(
name).second)
return;
278 const std::string cap =
Cap(
name);
279 accessors <<
"struct " << cap <<
"\n";
281 accessors <<
" template<typename S> static auto get(S&& s) -> decltype(s.get" << cap <<
"()) { return s.get" << cap <<
"(); }\n";
282 accessors <<
" template<typename S> static bool has(S&& s) { return s.has" << cap <<
"(); }\n";
283 accessors <<
" template<typename S, typename A> static void set(S&& s, A&& a) { s.set" << cap
284 <<
"(std::forward<A>(a)); }\n";
285 accessors <<
" template<typename S, typename... A> static decltype(auto) init(S&& s, A&&... a) { return s.init"
286 << cap <<
"(std::forward<A>(a)...); }\n";
287 accessors <<
" template<typename S> static bool getWant(S&& s) { return s.getWant" << cap <<
"(); }\n";
288 accessors <<
" template<typename S> static void setWant(S&& s) { s.setWant" << cap <<
"(true); }\n";
289 accessors <<
" template<typename S> static bool getHas(S&& s) { return s.getHas" << cap <<
"(); }\n";
290 accessors <<
" template<typename S> static void setHas(S&& s) { s.setHas" << cap <<
"(true); }\n";
// Walk every node nested directly in the schema file.
// proxied_class_type is the C++ type the node maps to; empty means
// "not wrapped". Its assignment (orig 298-299) is not visible; presumably
// it comes from WRAP_ANNOTATION_ID.
294 for (
const auto node_nested : file_schema.getProto().getNestedNodes()) {
295 kj::StringPtr node_name = node_nested.getName();
296 const auto&
node = file_schema.getNested(node_name);
297 kj::StringPtr proxied_class_type;
// Struct nodes: emit a ProxyStruct<ns::<generic_name>> specialization into
// `dec`. generic_name is the node name followed by its generic parameter
// list. The specialization contains:
//  - a <Field>Accessor alias per field, flagged FIELD_IN | FIELD_OUT plus
//    FIELD_BOXED for non-scalar types;
//  - an `Accessors` tuple;
//  - a `fields` count.
// NOTE(review): the declarations and increments of the counter `i`, and
// the template-header/brace lines, are not visible.
300 if (
node.getProto().isStruct()) {
301 const auto& struc =
node.asStruct();
302 std::ostringstream generic_name;
303 generic_name << node_name;
305 bool first_param =
true;
306 for (
const auto param :
node.getProto().getParameters()) {
312 generic_name <<
", ";
314 dec <<
"typename " << param.getName();
315 generic_name <<
"" << param.getName();
317 if (!first_param) generic_name <<
">";
319 dec <<
"struct ProxyStruct<" << message_namespace <<
"::" << generic_name.str() <<
">\n";
321 dec <<
" using Struct = " << message_namespace <<
"::" << generic_name.str() <<
";\n";
322 for (
const auto field : struc.getFields()) {
323 auto field_name = field.getProto().getName();
324 add_accessor(field_name);
325 dec <<
" using " <<
Cap(field_name) <<
"Accessor = Accessor<" << base_name
326 <<
"_fields::" <<
Cap(field_name) <<
", FIELD_IN | FIELD_OUT";
327 if (
BoxedType(field.getType())) dec <<
" | FIELD_BOXED";
330 dec <<
" using Accessors = std::tuple<";
332 for (
const auto field : struc.getFields()) {
337 dec <<
Cap(field.getProto().getName()) <<
"Accessor";
341 dec <<
" static constexpr size_t fields = " << i <<
";\n";
// Wrapped structs also get a ProxyType<C++ type> specialization in
// .proxy-types.h. Its get(std::integral_constant<size_t, i>) overloads
// return pointers to the corresponding C++ data members.
// member_name defaults to the capnp field name; any override (orig 357)
// is not visible, presumably via NAME_ANNOTATION_ID.
344 if (proxied_class_type.size()) {
345 inl <<
"template<>\n";
346 inl <<
"struct ProxyType<" << proxied_class_type <<
">\n";
349 inl <<
" using Struct = " << message_namespace <<
"::" << node_name <<
";\n";
351 for (
const auto field : struc.getFields()) {
355 auto field_name = field.getProto().getName();
356 auto member_name = field_name;
358 inl <<
" static decltype(auto) get(std::integral_constant<size_t, " << i <<
">) { return "
359 <<
"&" << proxied_class_type <<
"::" << member_name <<
"; }\n";
362 inl <<
" static constexpr size_t fields = " << i <<
";\n";
// Wrapped interface nodes: start building two class bodies in local
// streams:
//  - ProxyClient<ns::Name>: final, derived from ProxyClientCustom;
//  - ProxyServer<ns::Name>: derived from ProxyServerCustom.
// Methods are appended below, and the finished classes are written to
// `dec` afterwards. The rest of the ProxyServer heading (orig 381) is not
// visible.
367 if (proxied_class_type.size() &&
node.getProto().isInterface()) {
368 const auto&
interface = node.asInterface();
370 std::ostringstream client;
371 client <<
"template<>\nstruct ProxyClient<" << message_namespace <<
"::" << node_name <<
"> final : ";
372 client <<
"public ProxyClientCustom<" << message_namespace <<
"::" << node_name <<
", "
373 << proxied_class_type <<
">\n{\n";
374 client <<
"public:\n";
375 client <<
" using ProxyClientCustom::ProxyClientCustom;\n";
376 client <<
" ~ProxyClient();\n";
378 std::ostringstream server;
379 server <<
"template<>\nstruct ProxyServer<" << message_namespace <<
"::" << node_name <<
"> : public "
380 <<
"ProxyServerCustom<" << message_namespace <<
"::" << node_name <<
", " << proxied_class_type
382 server <<
"public:\n";
383 server <<
" using ProxyServerCustom::ProxyServerCustom;\n";
384 server <<
" ~ProxyServer();\n";
// NOTE(review): these are `const` ostringstreams, so only const members
// such as .str() can be used on them. Any write in the lines missing from
// this listing would not compile. As visible, client_destroy.str() is
// spliced into ~ProxyClient() and is therefore always empty, and
// client_construct is unused.
386 const std::ostringstream client_construct;
387 const std::ostringstream client_destroy;
// Visit every method via ForEachMethod, which also iterates superclasses.
// method_ordinal numbers the generated M<n> aliases; its increment is not
// visible.
// method_prefix is "<ns>::<DeclaringInterface>::<Cap(method)>". It is used
// to name the capnp "<prefix>Params" type.
// proxied_method_name defaults to the capnp method name. Any override
// (orig 393-394) is not visible, presumably via NAME_ANNOTATION_ID.
389 int method_ordinal = 0;
390 ForEachMethod(interface, [&] (
const capnp::InterfaceSchema& method_interface,
const capnp::InterfaceSchema::Method& method) {
391 const kj::StringPtr method_name = method.getProto().getName();
392 kj::StringPtr proxied_method_name = method_name;
395 const std::string method_prefix =
Format() << message_namespace <<
"::" << method_interface.getShortDisplayName()
396 <<
"::" <<
Cap(method_name);
// "construct" and "destroy" are special method names handled differently
// below.
397 const bool is_construct = method_name ==
"construct";
398 const bool is_destroy = method_name ==
"destroy";
// Per-field bookkeeping. These are presumably members of a local
// `struct Field`; its opening/closing lines and the `skip`, `retval` and
// `args` members are not visible. A field records:
//  - its Params-side and/or Results-side schema field;
//  - whether it is optional (a "has<Name>" companion exists);
//  - whether it is requested (a "want<Name>" companion exists);
//  - an exception type name.
402 ::capnp::StructSchema::Field param;
403 bool param_is_set =
false;
404 ::capnp::StructSchema::Field result;
405 bool result_is_set =
false;
408 bool optional =
false;
409 bool requested =
false;
411 kj::StringPtr exception;
414 std::vector<Field> fields;
415 std::map<kj::StringPtr, int> field_idx;
416 bool has_result =
false;
// add_field merges a schema field into `fields`, keyed by name. A name
// that appears in both Params and Results becomes one entry with both
// param_is_set and result_is_set.
// A result field named "result", and struct/interface-typed fields, get
// extra handling whose bodies (orig 438-455) are not visible.
418 auto add_field = [&](const ::capnp::StructSchema::Field& schema_field,
bool param) {
423 auto field_name = schema_field.getProto().getName();
424 auto inserted = field_idx.emplace(field_name, fields.size());
425 if (inserted.second) {
426 fields.emplace_back();
428 auto& field = fields[inserted.first->second];
430 field.param = schema_field;
431 field.param_is_set =
true;
433 field.result = schema_field;
434 field.result_is_set =
true;
437 if (!param && field_name ==
"result") {
446 if (schema_field.getType().isStruct()) {
449 }
else if (schema_field.getType().isInterface()) {
456 if (inserted.second && !field.retval && !field.exception.size()) {
// Collect Params fields first, then Results fields.
461 for (
const auto schema_field : method.getParamType().getFields()) {
462 add_field(schema_field,
true);
464 for (
const auto schema_field : method.getResultType().getFields()) {
465 add_field(schema_field,
false);
// Companion fields:
//  - "has<Name>" marks <Name> optional and is itself skipped;
//  - a Params-side "want<Name>" marks <Name> requested and is skipped.
467 for (
auto& field : field_idx) {
468 auto has_field = field_idx.find(
"has" +
Cap(field.first));
469 if (has_field != field_idx.end()) {
470 fields[has_field->second].skip =
true;
471 fields[field.second].optional =
true;
473 auto want_field = field_idx.find(
"want" +
Cap(field.first));
474 if (want_field != field_idx.end() && fields[want_field->second].param_is_set) {
475 fields[want_field->second].skip =
true;
476 fields[field.second].requested =
true;
// Methods declared directly on this interface get a
// ProxyMethod<<prefix>Params> specialization whose `impl` is a pointer to
// the wrapped C++ member function. construct/destroy are excluded.
// "Directly declared" is decided by an address comparison. It presumably
// holds only when ForEachMethod passed this very `interface` object, not a
// superclass schema.
480 if (!is_construct && !is_destroy && (&method_interface == &interface)) {
481 methods <<
"template<>\n";
482 methods <<
"struct ProxyMethod<" << method_prefix <<
"Params>\n";
484 methods <<
" static constexpr auto impl = &" << proxied_class_type
485 <<
"::" << proxied_method_name <<
";\n";
// For each non-skipped field, build four code fragments:
//  - client_args: the generated client method's parameter list.
//  - client_invoke: MakeClientParam<...>(...) or ClientException<...>
//    arguments for clientInvoke().
//  - server_invoke_start / server_invoke_end: nested MakeServerField /
//    Make<ServerRet / Make<ServerExcept wrappers for serverInvoke().
489 std::ostringstream client_args;
490 std::ostringstream client_invoke;
491 std::ostringstream server_invoke_start;
492 std::ostringstream server_invoke_end;
494 for (
const auto& field : fields) {
495 if (field.skip)
continue;
497 const auto& f = field.param_is_set ? field.param : field.result;
498 auto field_name = f.getProto().getName();
499 auto field_type = f.getType();
// Field direction flags:
//  - FIELD_OUT for Results-only fields;
//  - FIELD_IN | FIELD_OUT for fields in both structs;
//  - FIELD_IN otherwise.
// Then FIELD_OPTIONAL, FIELD_REQUESTED and FIELD_BOXED are added as
// applicable.
501 std::ostringstream field_flags;
502 if (!field.param_is_set) {
503 field_flags <<
"FIELD_OUT";
504 }
else if (field.result_is_set) {
505 field_flags <<
"FIELD_IN | FIELD_OUT";
507 field_flags <<
"FIELD_IN";
509 if (field.optional) field_flags <<
" | FIELD_OPTIONAL";
510 if (field.requested) field_flags <<
" | FIELD_REQUESTED";
511 if (
BoxedType(field_type)) field_flags <<
" | FIELD_BOXED";
513 add_accessor(field_name);
// One "M<n>::Param<argc> <name>" client parameter per C++ argument the
// field maps to. The name gets an index suffix when field.args > 1.
// The declaration and increment of `argc` are not visible.
515 for (
int i = 0; i < field.args; ++i) {
516 if (argc > 0) client_args <<
",";
517 client_args <<
"M" << method_ordinal <<
"::Param<" << argc <<
"> " << field_name;
518 if (field.args > 1) client_args << i;
521 client_invoke <<
", ";
523 if (field.exception.size()) {
524 client_invoke <<
"ClientException<" << field.exception <<
", ";
526 client_invoke <<
"MakeClientParam<";
529 client_invoke <<
"Accessor<" << base_name <<
"_fields::" <<
Cap(field_name) <<
", "
530 << field_flags.str() <<
">>(";
532 if (field.retval || field.args == 1) {
533 client_invoke << field_name;
535 for (
int i = 0; i < field.args; ++i) {
536 if (i > 0) client_invoke <<
", ";
537 client_invoke << field_name << i;
540 client_invoke <<
")";
// Server side: each field opens one wrapper in server_invoke_start and
// adds its closing ")" to server_invoke_end, so the wrappers nest.
542 if (field.exception.size()) {
543 server_invoke_start <<
"Make<ServerExcept, " << field.exception;
544 }
else if (field.retval) {
545 server_invoke_start <<
"Make<ServerRet";
547 server_invoke_start <<
"MakeServerField<" << field.args;
549 server_invoke_start <<
", Accessor<" << base_name <<
"_fields::" <<
Cap(field_name) <<
", "
550 << field_flags.str() <<
">>(";
551 server_invoke_end <<
")";
// construct/destroy become static client methods that take an explicit
// `Super& super` and invoke through `super`. All other methods are regular
// members invoking through *this.
// NOTE(review): super_str is concatenated directly with client_args.str()
// below, with no separator. That is only well-formed if client_args is
// empty or starts with "," for construct/destroy, which depends on argc's
// initialization (not visible). Confirm.
554 const std::string static_str{is_construct || is_destroy ?
"static " :
""};
555 const std::string super_str{is_construct || is_destroy ?
"Super& super" :
""};
556 const std::string self_str{is_construct || is_destroy ?
"super" :
"*this"};
// Client class: an M<n> alias of ProxyClientMethodTraits for the method
// (the rest of that statement, orig 559, is not visible), followed by the
// method declaration.
558 client <<
" using M" << method_ordinal <<
" = ProxyClientMethodTraits<" << method_prefix
560 client <<
" " << static_str <<
"typename M" << method_ordinal <<
"::Result " << method_name <<
"("
561 << super_str << client_args.str() <<
")";
// Client definition (def_client):
//  - declares `result`;
//  - calls clientInvoke(self, &ns::Name::Client::<method>Request,
//    <client_invoke fragments>);
//  - returns `result` only when the method has a result value.
563 def_client <<
"ProxyClient<" << message_namespace <<
"::" << node_name <<
">::M" << method_ordinal
564 <<
"::Result ProxyClient<" << message_namespace <<
"::" << node_name <<
">::" << method_name
565 <<
"(" << super_str << client_args.str() <<
") {\n";
567 def_client <<
" typename M" << method_ordinal <<
"::Result result;\n";
569 def_client <<
" clientInvoke(" << self_str <<
", &" << message_namespace <<
"::" << node_name
570 <<
"::Client::" << method_name <<
"Request" << client_invoke.str() <<
");\n";
571 if (has_result) def_client <<
" return result;\n";
// Server side:
//  - ProxyServer declares `kj::Promise<void> <method>(<Method>Context)
//    override`.
//  - def_server defines it as
//    serverInvoke(*this, call_context, <wrappers>(ServerDestroy() or
//    ServerCall())...).
// The condition choosing ServerDestroy vs ServerCall is not visible.
574 server <<
" kj::Promise<void> " << method_name <<
"(" <<
Cap(method_name)
575 <<
"Context call_context) override;\n";
577 def_server <<
"kj::Promise<void> ProxyServer<" << message_namespace <<
"::" << node_name
578 <<
">::" << method_name <<
"(" <<
Cap(method_name)
579 <<
"Context call_context) {\n"
580 " return serverInvoke(*this, call_context, "
581 << server_invoke_start.str();
583 def_server <<
"ServerDestroy()";
585 def_server <<
"ServerCall()";
587 def_server << server_invoke_end.str() <<
");\n}\n";
// Write the finished ProxyClient and ProxyServer classes into `dec`.
// Then emit a ProxyType<C++ type> specialization that maps the wrapped type
// to its Message, Client and Server types.
// NOTE(review): the body of the KJ_IF_MAYBE on '<' in proxied_class_type
// (orig 595-598) is not visible. Presumably templated wrapped types get
// special handling. Confirm.
593 dec <<
"\n" << client.str() <<
"\n" << server.str() <<
"\n";
594 KJ_IF_MAYBE(bracket, proxied_class_type.findFirst(
'<')) {
599 dec <<
"template<>\nstruct ProxyType<" << proxied_class_type <<
">\n{\n";
600 dec <<
" using Type = " << proxied_class_type <<
";\n";
601 dec <<
" using Message = " << message_namespace <<
"::" << node_name <<
";\n";
602 dec <<
" using Client = ProxyClient<Message>;\n";
603 dec <<
" using Server = ProxyServer<Message>;\n";
// int_client gets one ProxyTypeRegister object per interface. The object is
// named after the node's id and constructed from TypeList<C++ type>.
// def_types gets the out-of-line destructors, which call
// clientDestroy(*this) and serverDestroy(*this).
605 int_client <<
"ProxyTypeRegister t" << node_nested.getId() <<
"{TypeList<" << proxied_class_type <<
">{}};\n";
607 def_types <<
"ProxyClient<" << message_namespace <<
"::" << node_name
608 <<
">::~ProxyClient() { clientDestroy(*this); " << client_destroy.str() <<
" }\n";
609 def_types <<
"ProxyServer<" << message_namespace <<
"::" << node_name
610 <<
">::~ProxyServer() { serverDestroy(*this); }\n";
// Flush the buffered sections and close the output files:
//  - .proxy.h gets the ProxyMethod specializations, then the
//    "<base_name>_fields" accessor namespace. The rest of this statement
//    (orig 616-617) is not visible; `dec` is presumably flushed there.
//  - Each .c++ file gets its definitions and the closing
//    "} // namespace mp".
//  - .proxy-client.c++ also wraps the ProxyTypeRegister objects in an
//    anonymous namespace.
//  - .proxy.h ends by popping the GCC diagnostic state.
// The header "#endif" lines are not visible.
614 h << methods.str() <<
"namespace " << base_name <<
"_fields {\n"
615 << accessors.str() <<
"} // namespace " << base_name <<
"_fields\n"
618 cpp_server << def_server.str();
619 cpp_server <<
"} // namespace mp\n";
621 cpp_client << def_client.str();
622 cpp_client <<
"namespace {\n" << int_client.str() <<
"} // namespace\n";
623 cpp_client <<
"} // namespace mp\n";
625 cpp_types << def_types.str();
626 cpp_types <<
"} // namespace mp\n";
628 inl <<
"} // namespace mp\n";
631 h <<
"} // namespace mp\n";
632 h <<
"#if defined(__GNUC__)\n";
633 h <<
"#pragma GCC diagnostic pop\n";
// main() fragment; the signature and the argc check are not visible.
// Prints usage to stderr when the arguments are wrong.
641 std::cerr <<
"Usage: " <<
PROXY_BIN <<
" SRC_PREFIX INCLUDE_PREFIX SRC_FILE [IMPORT_PATH...]\n";
// import_paths holds non-owning kj::StringPtr views into argv and string
// literals, both of which outlive the Generate() call.
// import_dirs owns the opened directory handles.
644 std::vector<kj::StringPtr> import_paths;
645 std::vector<kj::Own<const kj::ReadableDirectory>> import_dirs;
646 auto fs = kj::newDiskFilesystem();
647 auto cwd = fs->getCurrentPath();
// Open SRC_PREFIX (argv[1], resolved against the current directory) as
// src_dir. Failure raises std::runtime_error.
// NOTE(review): the runtime message repeats "prefix" ("src_prefix prefix
// directory"). It is left unchanged because it is program output.
648 kj::Own<const kj::ReadableDirectory> src_dir;
649 KJ_IF_MAYBE(dir, fs->getRoot().tryOpenSubdir(cwd.evalNative(argv[1]))) {
650 src_dir = kj::mv(*dir);
652 throw std::runtime_error(std::string(
"Failed to open src_prefix prefix directory: ") + argv[1]);
// Every IMPORT_PATH (argv[4..]) must be openable; otherwise
// std::runtime_error is thrown.
654 for (
int i = 4; i < argc; ++i) {
655 KJ_IF_MAYBE(dir, fs->getRoot().tryOpenSubdir(cwd.evalNative(argv[i]))) {
656 import_paths.emplace_back(argv[i]);
657 import_dirs.emplace_back(kj::mv(*dir));
659 throw std::runtime_error(std::string(
"Failed to open import directory: ") + argv[i]);
// Also add the installed include directories as import paths:
// CMAKE_INSTALL_PREFIX/include and capnp_PREFIX/include.
// The lines after the success branch (orig 666-668) are not visible, so it
// is not shown whether a missing directory is skipped or reported.
662 for (
const char* path : {CMAKE_INSTALL_PREFIX
"/include", capnp_PREFIX
"/include"}) {
663 KJ_IF_MAYBE(dir, fs->getRoot().tryOpenSubdir(cwd.evalNative(path))) {
664 import_paths.emplace_back(path);
665 import_dirs.emplace_back(kj::mv(*dir));
// Generate(SRC_PREFIX, INCLUDE_PREFIX, SRC_FILE, import paths, src_dir,
// import dirs).
669 Generate(argv[1], argv[2], argv[3], import_paths, *src_dir, import_dirs);
kj::ArrayPtr< const char > CharSlice
constexpr uint64_t EXCEPTION_ANNOTATION_ID
constexpr uint64_t NAME_ANNOTATION_ID
static OutputStream & operator<<(OutputStream &os, const Array &array)
static void ForEachMethod(const capnp::InterfaceSchema &interface, const std::function< void(const capnp::InterfaceSchema &interface, const capnp::InterfaceSchema::Method)> &callback)
constexpr uint64_t SKIP_ANNOTATION_ID
int main(int argc, char **argv)
constexpr uint64_t WRAP_ANNOTATION_ID
constexpr uint64_t INCLUDE_ANNOTATION_ID
static void Generate(kj::StringPtr src_prefix, kj::StringPtr include_prefix, kj::StringPtr src_file, const std::vector< kj::StringPtr > &import_paths, const kj::ReadableDirectory &src_dir, const std::vector< kj::Own< const kj::ReadableDirectory > > &import_dirs)
static bool GetAnnotationInt32(const Reader &reader, uint64_t id, int32_t *result)
static bool BoxedType(const ::capnp::Type &type)
constexpr uint64_t INCLUDE_TYPES_ANNOTATION_ID
constexpr uint64_t NAMESPACE_ANNOTATION_ID
static bool AnnotationExists(const Reader &reader, uint64_t id)
static std::string Cap(kj::StringPtr str)
static bool GetAnnotationText(const Reader &reader, uint64_t id, kj::StringPtr *result)
constexpr uint64_t COUNT_ANNOTATION_ID
int WaitProcess(int pid)
Wait for a process to exit and return its exit code.
void ExecProcess(const std::vector< std::string > &args)
Call execvp with vector args.