// tinc_build/codegen/service/mod.rs

1use anyhow::Context;
2use indexmap::IndexMap;
3use openapi::{BodyMethod, GeneratedBody, GeneratedParams, InputGenerator, OutputGenerator};
4use openapiv3_1::HttpMethod;
5use quote::{format_ident, quote};
6use syn::{Ident, parse_quote};
7use tinc_pb_prost::http_endpoint_options;
8
9use super::Package;
10use super::utils::{field_ident_from_str, type_ident_from_str};
11use crate::types::{
12    Comments, ProtoPath, ProtoService, ProtoServiceMethod, ProtoServiceMethodEndpoint, ProtoServiceMethodIo,
13    ProtoTypeRegistry, ProtoValueType,
14};
15
// Endpoint input/output generators (`InputGenerator`, `OutputGenerator`): each call yields
// both the handler tokens and the matching OpenAPI parameters or request/response body.
mod openapi;
17
18struct GeneratedMethod {
19    function_body: proc_macro2::TokenStream,
20    openapi: openapiv3_1::path::PathItem,
21    http_method: Ident,
22    path: String,
23}
24
25impl GeneratedMethod {
26    #[allow(clippy::too_many_arguments)]
27    fn new(
28        name: &str,
29        package: &str,
30        service_name: &str,
31        service: &ProtoService,
32        method: &ProtoServiceMethod,
33        endpoint: &ProtoServiceMethodEndpoint,
34        types: &ProtoTypeRegistry,
35        components: &mut openapiv3_1::Components,
36    ) -> anyhow::Result<GeneratedMethod> {
37        let (http_method_oa, path) = match &endpoint.method {
38            tinc_pb_prost::http_endpoint_options::Method::Get(path) => (openapiv3_1::HttpMethod::Get, path),
39            tinc_pb_prost::http_endpoint_options::Method::Post(path) => (openapiv3_1::HttpMethod::Post, path),
40            tinc_pb_prost::http_endpoint_options::Method::Put(path) => (openapiv3_1::HttpMethod::Put, path),
41            tinc_pb_prost::http_endpoint_options::Method::Delete(path) => (openapiv3_1::HttpMethod::Delete, path),
42            tinc_pb_prost::http_endpoint_options::Method::Patch(path) => (openapiv3_1::HttpMethod::Patch, path),
43        };
44
45        let trimmed_path = path.trim_start_matches('/');
46        let full_path = if let Some(prefix) = &service.options.prefix {
47            format!("/{}/{}", prefix.trim_end_matches('/'), trimmed_path)
48        } else {
49            format!("/{trimmed_path}")
50        };
51
52        let http_method = quote::format_ident!("{http_method_oa}");
53        let tracker_ident = quote::format_ident!("tracker");
54        let target_ident = quote::format_ident!("target");
55        let state_ident = quote::format_ident!("state");
56        let mut openapi = openapiv3_1::path::Operation::new();
57        let mut generator = InputGenerator::new(
58            types,
59            components,
60            package,
61            method.input.value_type().clone(),
62            tracker_ident.clone(),
63            target_ident.clone(),
64            state_ident.clone(),
65        );
66
67        openapi.tag(service_name);
68
69        let GeneratedParams {
70            tokens: path_tokens,
71            params,
72        } = generator.generate_path_parameter(&full_path)?;
73        openapi.parameters(params);
74
75        let is_get_or_delete = matches!(http_method_oa, HttpMethod::Get | HttpMethod::Delete);
76        let request = endpoint.request.as_ref().and_then(|req| req.mode.clone()).unwrap_or_else(|| {
77            if is_get_or_delete {
78                http_endpoint_options::request::Mode::Query(http_endpoint_options::request::QueryParams::default())
79            } else {
80                http_endpoint_options::request::Mode::Json(http_endpoint_options::request::JsonBody::default())
81            }
82        });
83
84        let request_tokens = match request {
85            http_endpoint_options::request::Mode::Query(http_endpoint_options::request::QueryParams { field }) => {
86                let GeneratedParams { tokens, params } = generator.generate_query_parameter(field.as_deref())?;
87                openapi.parameters(params);
88                tokens
89            }
90            http_endpoint_options::request::Mode::Binary(http_endpoint_options::request::BinaryBody {
91                field,
92                content_type_accepts,
93                content_type_field,
94            }) => {
95                let GeneratedBody { tokens, body } = generator.generate_body(
96                    &method.cel,
97                    BodyMethod::Binary(content_type_accepts.as_deref()),
98                    field.as_deref(),
99                    content_type_field.as_deref(),
100                )?;
101                openapi.request_body = Some(body);
102                tokens
103            }
104            http_endpoint_options::request::Mode::Json(http_endpoint_options::request::JsonBody { field }) => {
105                let GeneratedBody { tokens, body } =
106                    generator.generate_body(&method.cel, BodyMethod::Json, field.as_deref(), None)?;
107                openapi.request_body = Some(body);
108                tokens
109            }
110            http_endpoint_options::request::Mode::Text(http_endpoint_options::request::TextBody { field }) => {
111                let GeneratedBody { tokens, body } =
112                    generator.generate_body(&method.cel, BodyMethod::Text, field.as_deref(), None)?;
113                openapi.request_body = Some(body);
114                tokens
115            }
116        };
117
118        let input_path = match &method.input {
119            ProtoServiceMethodIo::Single(input) => types.resolve_rust_path(package, input.proto_path()),
120            ProtoServiceMethodIo::Stream(_) => anyhow::bail!("currently streaming is not supported by tinc methods."),
121        };
122
123        let service_method_name = field_ident_from_str(name);
124
125        let response = endpoint
126            .response
127            .as_ref()
128            .and_then(|resp| resp.mode.clone())
129            .unwrap_or_else(
130                || http_endpoint_options::response::Mode::Json(http_endpoint_options::response::Json::default()),
131            );
132
133        let response_ident = quote::format_ident!("response");
134        let builder_ident = quote::format_ident!("builder");
135        let mut generator = OutputGenerator::new(
136            types,
137            components,
138            method.output.value_type().clone(),
139            response_ident.clone(),
140            builder_ident.clone(),
141        );
142
143        let GeneratedBody {
144            body: response,
145            tokens: response_tokens,
146        } = match response {
147            http_endpoint_options::response::Mode::Binary(http_endpoint_options::response::Binary {
148                field,
149                content_type_accepts,
150                content_type_field,
151            }) => generator.generate_body(
152                BodyMethod::Binary(content_type_accepts.as_deref()),
153                field.as_deref(),
154                content_type_field.as_deref(),
155            )?,
156            http_endpoint_options::response::Mode::Json(http_endpoint_options::response::Json { field }) => {
157                generator.generate_body(BodyMethod::Json, field.as_deref(), None)?
158            }
159            http_endpoint_options::response::Mode::Text(http_endpoint_options::response::Text { field }) => {
160                generator.generate_body(BodyMethod::Text, field.as_deref(), None)?
161            }
162        };
163
164        openapi.response("200", response);
165
166        let validate = if matches!(method.input.value_type(), ProtoValueType::Message(_)) {
167            quote! {
168                if let Err(err) = ::tinc::__private::TincValidate::validate_http(&#target_ident, #state_ident, &#tracker_ident) {
169                    return err;
170                }
171            }
172        } else {
173            quote!()
174        };
175
176        let function_impl = quote! {
177            let mut #state_ident = ::tinc::__private::TrackerSharedState::default();
178            let mut #tracker_ident = <<#input_path as ::tinc::__private::TrackerFor>::Tracker as ::core::default::Default>::default();
179            let mut #target_ident = <#input_path as ::core::default::Default>::default();
180
181            #path_tokens
182            #request_tokens
183
184            #validate
185
186            let request = ::tinc::reexports::tonic::Request::from_parts(
187                ::tinc::reexports::tonic::metadata::MetadataMap::from_headers(parts.headers),
188                parts.extensions,
189                target,
190            );
191
192            let (metadata, #response_ident, extensions) = match service.inner.#service_method_name(request).await {
193                ::core::result::Result::Ok(response) => response.into_parts(),
194                ::core::result::Result::Err(status) => return ::tinc::__private::handle_tonic_status(&status),
195            };
196
197            let mut response = {
198                let mut #builder_ident = ::tinc::reexports::http::Response::builder();
199                match #response_tokens {
200                    ::core::result::Result::Ok(v) => v,
201                    ::core::result::Result::Err(err) => return ::tinc::__private::handle_response_build_error(err),
202                }
203            };
204
205            response.headers_mut().extend(metadata.into_headers());
206            *response.extensions_mut() = extensions;
207
208            response
209        };
210
211        Ok(GeneratedMethod {
212            function_body: function_impl,
213            http_method,
214            openapi: openapiv3_1::PathItem::new(http_method_oa, openapi),
215            path: full_path,
216        })
217    }
218
219    pub(crate) fn method_handler(
220        &self,
221        function_name: &Ident,
222        server_module_name: &Ident,
223        service_trait: &Ident,
224        tinc_struct_name: &Ident,
225    ) -> proc_macro2::TokenStream {
226        let function_impl = &self.function_body;
227
228        quote! {
229            #[allow(non_snake_case, unused_mut, dead_code, unused_variables, unused_parens)]
230            async fn #function_name<T>(
231                ::tinc::reexports::axum::extract::State(service): ::tinc::reexports::axum::extract::State<#tinc_struct_name<T>>,
232                request: ::tinc::reexports::axum::extract::Request,
233            ) -> ::tinc::reexports::axum::response::Response
234            where
235                T: super::#server_module_name::#service_trait,
236            {
237                let (mut parts, body) = ::tinc::reexports::axum::RequestExt::with_limited_body(request).into_parts();
238                #function_impl
239            }
240        }
241    }
242
243    pub(crate) fn route(&self, function_name: &Ident) -> proc_macro2::TokenStream {
244        let path = &self.path;
245        let http_method = &self.http_method;
246
247        quote! {
248            .route(#path, ::tinc::reexports::axum::routing::#http_method(#function_name::<T>))
249        }
250    }
251}
252
/// Summary of a processed protobuf service, appended to `package.services` by
/// `handle_service` after the service's tinc module has been generated.
#[derive(Debug, Clone, PartialEq)]
pub(crate) struct ProcessedService {
    /// Full proto name of the service (copied from the source `ProtoService`).
    pub full_name: ProtoPath,
    /// Proto package the service belongs to.
    pub package: ProtoPath,
    /// Comments copied from the source service.
    pub comments: Comments,
    /// OpenAPI document covering every HTTP endpoint generated for the service.
    pub openapi: openapiv3_1::OpenApi,
    /// Per-method summaries keyed by method name, in the iteration order of the source
    /// service's methods.
    pub methods: IndexMap<String, ProcessedServiceMethod>,
}
261
262impl ProcessedService {
263    pub(crate) fn name(&self) -> &str {
264        self.full_name
265            .strip_prefix(&*self.package)
266            .unwrap_or(&self.full_name)
267            .trim_matches('.')
268    }
269}
270
/// Per-method summary stored in `ProcessedService::methods`.
#[derive(Debug, Clone, PartialEq)]
pub(crate) struct ProcessedServiceMethod {
    /// Proto path of the generated validating `<Method>Codec` wrapper, or `None` when the
    /// method's input is not a message type and no codec was generated.
    pub codec_path: Option<ProtoPath>,
    /// Input (request) side of the method, single or streaming.
    pub input: ProtoServiceMethodIo,
    /// Output (response) side of the method, single or streaming.
    pub output: ProtoServiceMethodIo,
    /// Comments copied from the source method.
    pub comments: Comments,
}
278
279pub(super) fn handle_service(
280    service: &ProtoService,
281    package: &mut Package,
282    registry: &ProtoTypeRegistry,
283) -> anyhow::Result<()> {
284    let name = service
285        .full_name
286        .strip_prefix(&*service.package)
287        .and_then(|s| s.strip_prefix('.'))
288        .unwrap_or(&*service.full_name);
289
290    let mut components = openapiv3_1::Components::new();
291    let mut paths = openapiv3_1::Paths::builder();
292
293    let snake_name = field_ident_from_str(name);
294    let pascal_name = type_ident_from_str(name);
295
296    let tinc_module_name = quote::format_ident!("{snake_name}_tinc");
297    let server_module_name = quote::format_ident!("{snake_name}_server");
298    let tinc_struct_name = quote::format_ident!("{pascal_name}Tinc");
299
300    let mut method_tokens = Vec::new();
301    let mut route_tokens = Vec::new();
302    let mut method_codecs = Vec::new();
303    let mut methods = IndexMap::new();
304
305    let package_name = format!("{}.{tinc_module_name}", service.package);
306
307    for (method_name, method) in service.methods.iter() {
308        for (idx, endpoint) in method.endpoints.iter().enumerate() {
309            let gen_method = GeneratedMethod::new(
310                method_name,
311                &package_name,
312                name,
313                service,
314                method,
315                endpoint,
316                registry,
317                &mut components,
318            )?;
319            let function_name = quote::format_ident!("{method_name}_{idx}");
320
321            method_tokens.push(gen_method.method_handler(
322                &function_name,
323                &server_module_name,
324                &pascal_name,
325                &tinc_struct_name,
326            ));
327            route_tokens.push(gen_method.route(&function_name));
328            paths = paths.path(gen_method.path, gen_method.openapi);
329        }
330
331        let codec_path = if matches!(method.input.value_type(), ProtoValueType::Message(_)) {
332            let input_path = registry.resolve_rust_path(&package_name, method.input.value_type().proto_path());
333            let output_path = registry.resolve_rust_path(&package_name, method.output.value_type().proto_path());
334            let codec_ident = format_ident!("{method_name}Codec");
335            method_codecs.push(quote! {
336                #[derive(Debug, Clone, Default)]
337                #[doc(hidden)]
338                pub struct #codec_ident<C>(C);
339
340                #[allow(clippy::all, dead_code, unused_imports, unused_variables, unused_parens)]
341                const _: () = {
342                    #[derive(Debug, Clone, Default)]
343                    pub struct Encoder<E>(E);
344                    #[derive(Debug, Clone, Default)]
345                    pub struct Decoder<D>(D);
346
347                    impl<C> ::tinc::reexports::tonic::codec::Codec for #codec_ident<C>
348                    where
349                        C: ::tinc::reexports::tonic::codec::Codec<Encode = #output_path, Decode = #input_path>
350                    {
351                        type Encode = C::Encode;
352                        type Decode = C::Decode;
353
354                        type Encoder = C::Encoder;
355                        type Decoder = Decoder<C::Decoder>;
356
357                        fn encoder(&mut self) -> Self::Encoder {
358                            ::tinc::reexports::tonic::codec::Codec::encoder(&mut self.0)
359                        }
360
361                        fn decoder(&mut self) -> Self::Decoder {
362                            Decoder(
363                                ::tinc::reexports::tonic::codec::Codec::decoder(&mut self.0)
364                            )
365                        }
366                    }
367
368                    impl<D> ::tinc::reexports::tonic::codec::Decoder for Decoder<D>
369                    where
370                        D: ::tinc::reexports::tonic::codec::Decoder<Item = #input_path, Error = ::tinc::reexports::tonic::Status>
371                    {
372                        type Item = D::Item;
373                        type Error = ::tinc::reexports::tonic::Status;
374
375                        fn decode(&mut self, buf: &mut ::tinc::reexports::tonic::codec::DecodeBuf<'_>) -> Result<Option<Self::Item>, Self::Error> {
376                            match ::tinc::reexports::tonic::codec::Decoder::decode(&mut self.0, buf) {
377                                ::core::result::Result::Ok(::core::option::Option::Some(item)) => {
378                                    ::tinc::__private::TincValidate::validate_tonic(&item)?;
379                                    ::core::result::Result::Ok(::core::option::Option::Some(item))
380                                },
381                                ::core::result::Result::Ok(::core::option::Option::None) => ::core::result::Result::Ok(::core::option::Option::None),
382                                ::core::result::Result::Err(err) => ::core::result::Result::Err(err),
383                            }
384                        }
385
386                        fn buffer_settings(&self) -> ::tinc::reexports::tonic::codec::BufferSettings {
387                            ::tinc::reexports::tonic::codec::Decoder::buffer_settings(&self.0)
388                        }
389                    }
390                };
391            });
392            Some(ProtoPath::new(format!("{package_name}.{codec_ident}")))
393        } else {
394            None
395        };
396
397        methods.insert(
398            method_name.clone(),
399            ProcessedServiceMethod {
400                codec_path,
401                input: method.input.clone(),
402                output: method.output.clone(),
403                comments: method.comments.clone(),
404            },
405        );
406    }
407
408    let openapi_tag = openapiv3_1::Tag::builder()
409        .name(name)
410        .description(service.comments.to_string())
411        .build();
412    let openapi = openapiv3_1::OpenApi::builder()
413        .components(components)
414        .paths(paths)
415        .tags(vec![openapi_tag])
416        .build();
417
418    let json_openapi = openapi.to_json().context("invalid openapi schema generation")?;
419
420    package.push_item(parse_quote! {
421        /// This module was automatically generated by `tinc`.
422        #[allow(clippy::all)]
423        pub mod #tinc_module_name {
424            #![allow(
425                unused_variables,
426                dead_code,
427                missing_docs,
428                clippy::wildcard_imports,
429                clippy::let_unit_value,
430                unused_parens,
431                irrefutable_let_patterns,
432            )]
433
434            /// A tinc service struct that exports gRPC routes via an axum router.
435            pub struct #tinc_struct_name<T> {
436                inner: ::std::sync::Arc<T>,
437            }
438
439            impl<T> #tinc_struct_name<T> {
440                /// Create a new tinc service struct from a service implementation.
441                pub fn new(inner: T) -> Self {
442                    Self { inner: ::std::sync::Arc::new(inner) }
443                }
444
445                /// Create a new tinc service struct from an existing `Arc`.
446                pub fn from_arc(inner: ::std::sync::Arc<T>) -> Self {
447                    Self { inner }
448                }
449            }
450
451            impl<T> ::std::clone::Clone for #tinc_struct_name<T> {
452                fn clone(&self) -> Self {
453                    Self { inner: ::std::clone::Clone::clone(&self.inner) }
454                }
455            }
456
457            impl<T> ::std::fmt::Debug for #tinc_struct_name<T> {
458                fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
459                    write!(f, stringify!(#tinc_struct_name))
460                }
461            }
462
463            impl<T> ::tinc::TincService for #tinc_struct_name<T>
464            where
465                T: super::#server_module_name::#pascal_name
466            {
467                fn into_router(self) -> ::tinc::reexports::axum::Router {
468                    #(#method_tokens)*
469
470                    ::tinc::reexports::axum::Router::new()
471                        #(#route_tokens)*
472                        .with_state(self)
473                }
474
475                fn openapi_schema_str(&self) -> &'static str {
476                    #json_openapi
477                }
478            }
479
480            #(#method_codecs)*
481        }
482    });
483
484    package.services.push(ProcessedService {
485        full_name: service.full_name.clone(),
486        package: service.package.clone(),
487        comments: service.comments.clone(),
488        openapi,
489        methods,
490    });
491
492    Ok(())
493}