tinc_build/
lib.rs

1//! The code generator for [`tinc`](https://crates.io/crates/tinc).
2#![cfg_attr(feature = "docs", doc = "## Feature flags")]
3#![cfg_attr(feature = "docs", doc = document_features::document_features!())]
4//! ## Usage
5//!
6//! In your `build.rs`:
7//!
8//! ```rust,no_run
9//! # #[allow(clippy::needless_doctest_main)]
10//! fn main() {
11//!     tinc_build::Config::prost()
12//!         .compile_protos(&["proto/test.proto"], &["proto"])
13//!         .unwrap();
14//! }
15//! ```
16//!
17//! Look at [`Config`] to see different options to configure the generator.
18//!
19//! ## License
20//!
21//! This project is licensed under the MIT or Apache-2.0 license.
//! You may use this work under the terms of either license, at your option.
23//!
24//! `SPDX-License-Identifier: MIT OR Apache-2.0`
25#![cfg_attr(all(coverage_nightly, test), feature(coverage_attribute))]
26#![cfg_attr(docsrs, feature(doc_auto_cfg))]
27#![deny(missing_docs)]
28#![deny(unsafe_code)]
29#![deny(unreachable_pub)]
30#![cfg_attr(not(feature = "prost"), allow(unused_variables, dead_code))]
31
32use std::io::ErrorKind;
33use std::path::{Path, PathBuf};
34
35use anyhow::Context;
36use extern_paths::ExternPaths;
37mod codegen;
38mod extern_paths;
39
40#[cfg(feature = "prost")]
41mod prost_explore;
42
43mod types;
44
/// The code generation backend used by tinc.
///
/// Currently `prost` is the only supported backend.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Mode {
    /// Use `prost` to generate the protobuf structures.
    #[cfg(feature = "prost")]
    Prost,
}
52
53impl quote::ToTokens for Mode {
54    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
55        match self {
56            #[cfg(feature = "prost")]
57            Mode::Prost => quote::quote!(prost).to_tokens(tokens),
58            #[cfg(not(feature = "prost"))]
59            _ => unreachable!(),
60        }
61    }
62}
63
/// Per-protobuf-path overrides that are forwarded to `prost_build`.
///
/// Each entry is the path string exactly as it was given to the matching
/// [`Config`] setter (`btree_map`, `bytes` or `boxed`).
#[derive(Debug, Default)]
struct PathConfigs {
    /// Paths whose proto maps should be generated as `BTreeMap` instead of `HashMap`.
    btree_maps: Vec<String>,
    /// Paths whose proto bytes should be generated as `bytes::Bytes` instead of `Vec<u8>`.
    bytes: Vec<String>,
    /// Paths that should be wrapped in a `Box` rather than stored inline.
    boxed: Vec<String>,
}
70
71/// A config for configuring how tinc builds / generates code.
72#[derive(Debug)]
73pub struct Config {
74    disable_tinc_include: bool,
75    root_module: bool,
76    mode: Mode,
77    paths: PathConfigs,
78    extern_paths: ExternPaths,
79    out_dir: PathBuf,
80}
81
82impl Config {
83    /// New config with prost mode.
84    #[cfg(feature = "prost")]
85    pub fn prost() -> Self {
86        Self::new(Mode::Prost)
87    }
88
89    /// Make a new config with a given mode.
90    pub fn new(mode: Mode) -> Self {
91        Self::new_with_out_dir(mode, std::env::var_os("OUT_DIR").expect("OUT_DIR not set"))
92    }
93
94    /// Make a new config with a given mode.
95    pub fn new_with_out_dir(mode: Mode, out_dir: impl Into<PathBuf>) -> Self {
96        Self {
97            disable_tinc_include: false,
98            mode,
99            paths: PathConfigs::default(),
100            extern_paths: ExternPaths::new(mode),
101            root_module: true,
102            out_dir: out_dir.into(),
103        }
104    }
105
106    /// Disable tinc auto-include. By default tinc will add its own
107    /// annotations into the include path of protoc.
108    pub fn disable_tinc_include(&mut self) -> &mut Self {
109        self.disable_tinc_include = true;
110        self
111    }
112
113    /// Disable the root module generation
114    /// which allows for `tinc::include_protos!()` without
115    /// providing a package.
116    pub fn disable_root_module(&mut self) -> &mut Self {
117        self.root_module = false;
118        self
119    }
120
121    /// Specify a path to generate a `BTreeMap` instead of a `HashMap` for proto map.
122    pub fn btree_map(&mut self, path: impl std::fmt::Display) -> &mut Self {
123        self.paths.btree_maps.push(path.to_string());
124        self
125    }
126
127    /// Specify a path to generate `bytes::Bytes` instead of `Vec<u8>` for proto bytes.
128    pub fn bytes(&mut self, path: impl std::fmt::Display) -> &mut Self {
129        self.paths.bytes.push(path.to_string());
130        self
131    }
132
133    /// Specify a path to wrap around a `Box` instead of including it directly into the struct.
134    pub fn boxed(&mut self, path: impl std::fmt::Display) -> &mut Self {
135        self.paths.boxed.push(path.to_string());
136        self
137    }
138
139    /// Compile and generate all the protos with the includes.
140    pub fn compile_protos(&mut self, protos: &[impl AsRef<Path>], includes: &[impl AsRef<Path>]) -> anyhow::Result<()> {
141        match self.mode {
142            #[cfg(feature = "prost")]
143            Mode::Prost => self.compile_protos_prost(protos, includes),
144        }
145    }
146
147    /// Generate tinc code based on a precompiled FileDescriptorSet.
148    pub fn load_fds(&mut self, fds: impl bytes::Buf) -> anyhow::Result<()> {
149        match self.mode {
150            #[cfg(feature = "prost")]
151            Mode::Prost => self.load_fds_prost(fds),
152        }
153    }
154
155    #[cfg(feature = "prost")]
156    fn compile_protos_prost(&mut self, protos: &[impl AsRef<Path>], includes: &[impl AsRef<Path>]) -> anyhow::Result<()> {
157        let fd_path = self.out_dir.join("tinc.fd.bin");
158
159        let mut config = prost_build::Config::new();
160        config.file_descriptor_set_path(&fd_path);
161
162        let mut includes = includes.iter().map(|i| i.as_ref()).collect::<Vec<_>>();
163
164        {
165            let tinc_out = self.out_dir.join("tinc");
166            std::fs::create_dir_all(&tinc_out).context("failed to create tinc directory")?;
167            std::fs::write(tinc_out.join("annotations.proto"), tinc_pb_prost::TINC_ANNOTATIONS)
168                .context("failed to write tinc_annotations.rs")?;
169            includes.push(&self.out_dir);
170        }
171
172        config.load_fds(protos, &includes).context("failed to generate tonic fds")?;
173        let fds_bytes = std::fs::read(fd_path).context("failed to read tonic fds")?;
174        self.load_fds_prost(fds_bytes.as_slice())
175    }
176
177    #[cfg(feature = "prost")]
178    fn load_fds_prost(&mut self, fds: impl bytes::Buf) -> anyhow::Result<()> {
179        use std::collections::BTreeMap;
180
181        use codegen::prost_sanatize::to_snake;
182        use codegen::utils::get_common_import_path;
183        use proc_macro2::Span;
184        use prost::Message;
185        use prost_reflect::DescriptorPool;
186        use prost_types::FileDescriptorSet;
187        use quote::{ToTokens, quote};
188        use syn::parse_quote;
189        use types::{ProtoPath, ProtoTypeRegistry};
190
191        let pool = DescriptorPool::decode(fds).context("failed to add tonic fds")?;
192
193        let mut registry = ProtoTypeRegistry::new(self.mode, self.extern_paths.clone());
194
195        let mut config = prost_build::Config::new();
196
197        // This option is provided to make sure prost_build does not internally
198        // set extern_paths. We manage that via a re-export of prost_types in the
199        // tinc crate.
200        config.compile_well_known_types();
201
202        config.btree_map(self.paths.btree_maps.iter());
203        self.paths.boxed.iter().for_each(|path| {
204            config.boxed(path);
205        });
206        config.bytes(self.paths.bytes.iter());
207
208        for (proto, rust) in self.extern_paths.paths() {
209            let proto = if proto.starts_with('.') {
210                proto.to_string()
211            } else {
212                format!(".{proto}")
213            };
214            config.extern_path(proto, rust.to_token_stream().to_string());
215        }
216
217        prost_explore::Extensions::new(&pool)
218            .process(&mut registry)
219            .context("failed to process extensions")?;
220
221        let mut packages = codegen::generate_modules(&registry)?;
222
223        packages.iter_mut().for_each(|(path, package)| {
224            if self.extern_paths.contains(path) {
225                return;
226            }
227
228            package.enum_configs().for_each(|(path, enum_config)| {
229                if self.extern_paths.contains(path) {
230                    return;
231                }
232
233                enum_config.attributes().for_each(|attribute| {
234                    config.enum_attribute(path, attribute.to_token_stream().to_string());
235                });
236                enum_config.variants().for_each(|variant| {
237                    let path = format!("{path}.{variant}");
238                    enum_config.variant_attributes(variant).for_each(|attribute| {
239                        config.field_attribute(&path, attribute.to_token_stream().to_string());
240                    });
241                });
242            });
243
244            package.message_configs().for_each(|(path, message_config)| {
245                if self.extern_paths.contains(path) {
246                    return;
247                }
248
249                message_config.attributes().for_each(|attribute| {
250                    config.message_attribute(path, attribute.to_token_stream().to_string());
251                });
252                message_config.fields().for_each(|field| {
253                    let path = format!("{path}.{field}");
254                    message_config.field_attributes(field).for_each(|attribute| {
255                        config.field_attribute(&path, attribute.to_token_stream().to_string());
256                    });
257                });
258                message_config.oneof_configs().for_each(|(field, oneof_config)| {
259                    let path = format!("{path}.{field}");
260                    oneof_config.attributes().for_each(|attribute| {
261                        // In prost oneofs (container) are treated as enums
262                        config.enum_attribute(&path, attribute.to_token_stream().to_string());
263                    });
264                    oneof_config.fields().for_each(|field| {
265                        let path = format!("{path}.{field}");
266                        oneof_config.field_attributes(field).for_each(|attribute| {
267                            config.field_attribute(&path, attribute.to_token_stream().to_string());
268                        });
269                    });
270                });
271            });
272
273            package.extra_items.extend(package.services.iter().flat_map(|service| {
274                let mut builder = tonic_build::CodeGenBuilder::new();
275
276                builder.emit_package(true).build_transport(true);
277
278                let make_service = |is_client: bool| {
279                    let mut builder = tonic_build::manual::Service::builder()
280                        .name(service.name())
281                        .package(&service.package);
282
283                    if !service.comments.is_empty() {
284                        builder = builder.comment(service.comments.to_string());
285                    }
286
287                    service
288                        .methods
289                        .iter()
290                        .fold(builder, |service_builder, (name, method)| {
291                            let codec_path =
292                                if let Some(Some(codec_path)) = (!is_client).then_some(method.codec_path.as_ref()) {
293                                    let path = get_common_import_path(&service.full_name, codec_path);
294                                    quote!(#path::<::tinc::reexports::tonic_prost::ProstCodec<_, _>>)
295                                } else {
296                                    quote!(::tinc::reexports::tonic_prost::ProstCodec)
297                                };
298
299                            let mut builder = tonic_build::manual::Method::builder()
300                                .input_type(
301                                    registry
302                                        .resolve_rust_path(&service.full_name, method.input.value_type().proto_path())
303                                        .unwrap()
304                                        .to_token_stream()
305                                        .to_string(),
306                                )
307                                .output_type(
308                                    registry
309                                        .resolve_rust_path(&service.full_name, method.output.value_type().proto_path())
310                                        .unwrap()
311                                        .to_token_stream()
312                                        .to_string(),
313                                )
314                                .codec_path(codec_path.to_string())
315                                .name(to_snake(name))
316                                .route_name(name);
317
318                            if method.input.is_stream() {
319                                builder = builder.client_streaming()
320                            }
321
322                            if method.output.is_stream() {
323                                builder = builder.server_streaming();
324                            }
325
326                            if !method.comments.is_empty() {
327                                builder = builder.comment(method.comments.to_string());
328                            }
329
330                            service_builder.method(builder.build())
331                        })
332                        .build()
333                };
334
335                let mut client: syn::ItemMod = syn::parse2(builder.generate_client(&make_service(true), "")).unwrap();
336                client.content.as_mut().unwrap().1.insert(
337                    0,
338                    parse_quote!(
339                        use ::tinc::reexports::tonic;
340                    ),
341                );
342
343                let mut server: syn::ItemMod = syn::parse2(builder.generate_server(&make_service(false), "")).unwrap();
344                server.content.as_mut().unwrap().1.insert(
345                    0,
346                    parse_quote!(
347                        use ::tinc::reexports::tonic;
348                    ),
349                );
350
351                [client.into(), server.into()]
352            }));
353        });
354
355        for package in packages.keys() {
356            match std::fs::remove_file(self.out_dir.join(format!("{package}.rs"))) {
357                Err(err) if err.kind() != ErrorKind::NotFound => return Err(anyhow::anyhow!(err).context("remove")),
358                _ => {}
359            }
360        }
361
362        let fds = FileDescriptorSet {
363            file: pool.file_descriptor_protos().cloned().collect(),
364        };
365
366        let fd_path = self.out_dir.join("tinc.fd.bin");
367        std::fs::write(fd_path, fds.encode_to_vec()).context("write fds")?;
368
369        config.compile_fds(fds).context("prost compile")?;
370
371        for (package, module) in &mut packages {
372            if self.extern_paths.contains(package) {
373                continue;
374            };
375
376            let path = self.out_dir.join(format!("{package}.rs"));
377            write_module(&path, std::mem::take(&mut module.extra_items)).with_context(|| package.to_owned())?;
378        }
379
380        #[derive(Default)]
381        struct Module<'a> {
382            proto_path: Option<&'a ProtoPath>,
383            children: BTreeMap<&'a str, Module<'a>>,
384        }
385
386        impl ToTokens for Module<'_> {
387            fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
388                let include = self
389                    .proto_path
390                    .map(|p| p.as_ref())
391                    .map(|path| quote!(include!(concat!(#path, ".rs"));));
392                let children = self.children.iter().map(|(part, child)| {
393                    let ident = syn::Ident::new(&to_snake(part), Span::call_site());
394                    quote! {
395                        pub mod #ident {
396                            #child
397                        }
398                    }
399                });
400                quote! {
401                    #include
402                    #(#children)*
403                }
404                .to_tokens(tokens);
405            }
406        }
407
408        if self.root_module {
409            let mut module = Module::default();
410            for package in packages.keys() {
411                let mut module = &mut module;
412                for part in package.split('.') {
413                    module = module.children.entry(part).or_default();
414                }
415                module.proto_path = Some(package);
416            }
417
418            let file: syn::File = parse_quote!(#module);
419            std::fs::write(self.out_dir.join("___root_module.rs"), prettyplease::unparse(&file))
420                .context("write root module")?;
421        }
422
423        Ok(())
424    }
425}
426
427fn write_module(path: &std::path::Path, module: Vec<syn::Item>) -> anyhow::Result<()> {
428    let mut file = match std::fs::read_to_string(path) {
429        Ok(content) if !content.is_empty() => syn::parse_file(&content).context("parse")?,
430        Err(err) if err.kind() != ErrorKind::NotFound => return Err(anyhow::anyhow!(err).context("read")),
431        _ => syn::File {
432            attrs: Vec::new(),
433            items: Vec::new(),
434            shebang: None,
435        },
436    };
437
438    file.items.extend(module);
439    std::fs::write(path, prettyplease::unparse(&file)).context("write")?;
440
441    Ok(())
442}