//! Implement COM interfaces for Rust types. //! //! Take a look at [macro@implement] for an example. //! //! Learn more about Rust for Windows here: use quote::{quote, ToTokens}; mod r#gen; use r#gen::gen_all; #[cfg(test)] mod tests; /// Implements one or more COM interfaces. /// /// # Example /// ```rust,no_run /// use windows_core::*; /// /// #[interface("094d70d6-5202-44b8-abb8-43860da5aca2")] /// unsafe trait IValue: IUnknown { /// fn GetValue(&self, value: *mut i32) -> HRESULT; /// } /// /// #[implement(IValue)] /// struct Value(i32); /// /// impl IValue_Impl for Value_Impl { /// unsafe fn GetValue(&self, value: *mut i32) -> HRESULT { /// *value = self.0; /// HRESULT(0) /// } /// } /// /// let object: IValue = Value(123).into(); /// // Call interface methods... /// ``` #[proc_macro_attribute] pub fn implement( attributes: proc_macro::TokenStream, type_tokens: proc_macro::TokenStream, ) -> proc_macro::TokenStream { implement_core(attributes.into(), type_tokens.into()).into() } fn implement_core( attributes: proc_macro2::TokenStream, item_tokens: proc_macro2::TokenStream, ) -> proc_macro2::TokenStream { let attributes = syn::parse2::(attributes).unwrap(); let original_type = syn::parse2::(item_tokens).unwrap(); // Do a little thinking and assemble ImplementInputs. We pass ImplementInputs to // all of our gen_* function. let inputs = ImplementInputs { original_ident: original_type.ident.clone(), interface_chains: convert_implements_to_interface_chains(attributes.implement), trust_level: attributes.trust_level, impl_ident: quote::format_ident!("{}_Impl", &original_type.ident), constraints: { if let Some(where_clause) = &original_type.generics.where_clause { where_clause.predicates.to_token_stream() } else { quote!() } }, generics: if !original_type.generics.params.is_empty() { let mut params = quote! {}; original_type.generics.params.to_tokens(&mut params); quote! { <#params> } } else { quote! { <> } }, is_generic: !original_type.generics.params.is_empty(), original_type, }; let items = gen_all(&inputs); let mut tokens = inputs.original_type.into_token_stream(); for item in items { tokens.extend(item.into_token_stream()); } tokens } /// This provides the inputs to the `gen_*` functions, which generate the proc macro output. struct ImplementInputs { /// The user's type that was marked with `#[implement]`. original_type: syn::ItemStruct, /// The identifier for the user's original type definition. original_ident: syn::Ident, /// The list of interface chains that this type implements. interface_chains: Vec, /// The "trust level", which is returned by `IInspectable::GetTrustLevel`. trust_level: usize, /// The identifier of the `Foo_Impl` type. impl_ident: syn::Ident, /// The list of constraints needed for this `Foo_Impl` type. constraints: proc_macro2::TokenStream, /// The list of generic parameters for this `Foo_Impl` type, including `<` and `>`. /// If there are no generics, this contains `<>`. generics: proc_macro2::TokenStream, /// True if the user type has any generic parameters. is_generic: bool, } /// Describes one COM interface chain. struct InterfaceChain { /// The name of the field for the vtable chain, e.g. `interface4_ifoo`. field_ident: syn::Ident, /// The name of the associated constant item for the vtable chain's initializer, /// e.g. `INTERFACE4_IFOO_VTABLE`. vtable_const_ident: syn::Ident, implement: ImplementType, } struct ImplementType { type_name: String, generics: Vec, /// The best span for diagnostics. span: proc_macro2::Span, } impl ImplementType { fn to_ident(&self) -> proc_macro2::TokenStream { let type_name = syn::parse_str::(&self.type_name) .expect("Invalid token stream"); let generics = self.generics.iter().map(|g| g.to_ident()); quote! { #type_name<#(#generics,)*> } } fn to_vtbl_ident(&self) -> proc_macro2::TokenStream { let ident = self.to_ident(); quote! { <#ident as ::windows_core::Interface>::Vtable } } } #[derive(Default)] struct ImplementAttributes { pub implement: Vec, pub trust_level: usize, } impl syn::parse::Parse for ImplementAttributes { fn parse(cursor: syn::parse::ParseStream<'_>) -> syn::parse::Result { let mut input = Self::default(); while !cursor.is_empty() { input.parse_implement(cursor)?; } Ok(input) } } impl ImplementAttributes { fn parse_implement(&mut self, cursor: syn::parse::ParseStream<'_>) -> syn::parse::Result<()> { let tree = cursor.parse::()?; self.walk_implement(&tree, &mut String::new())?; if !cursor.is_empty() { cursor.parse::()?; } Ok(()) } fn walk_implement( &mut self, tree: &UseTree2, namespace: &mut String, ) -> syn::parse::Result<()> { match tree { UseTree2::Path(input) => { if !namespace.is_empty() { namespace.push_str("::"); } namespace.push_str(&input.ident.to_string()); self.walk_implement(&input.tree, namespace)?; } UseTree2::Name(_) => { self.implement.push(tree.to_element_type(namespace)?); } UseTree2::Group(input) => { for tree in &input.items { self.walk_implement(tree, namespace)?; } } UseTree2::TrustLevel(input) => self.trust_level = *input, } Ok(()) } } enum UseTree2 { Path(UsePath2), Name(UseName2), Group(UseGroup2), TrustLevel(usize), } impl UseTree2 { fn to_element_type(&self, namespace: &mut String) -> syn::parse::Result { match self { UseTree2::Path(input) => { if !namespace.is_empty() { namespace.push_str("::"); } namespace.push_str(&input.ident.to_string()); input.tree.to_element_type(namespace) } UseTree2::Name(input) => { let mut type_name = input.ident.to_string(); let span = input.ident.span(); if !namespace.is_empty() { type_name = format!("{namespace}::{type_name}"); } let mut generics = vec![]; for g in &input.generics { generics.push(g.to_element_type(&mut String::new())?); } Ok(ImplementType { type_name, generics, span, }) } UseTree2::Group(input) => Err(syn::parse::Error::new( input.brace_token.span.join(), "Syntax not supported", )), _ => unimplemented!(), } } } struct UsePath2 { pub ident: syn::Ident, pub tree: Box, } struct UseName2 { pub ident: syn::Ident, pub generics: Vec, } struct UseGroup2 { pub brace_token: syn::token::Brace, pub items: syn::punctuated::Punctuated, } impl syn::parse::Parse for UseTree2 { fn parse(input: syn::parse::ParseStream<'_>) -> syn::parse::Result { let lookahead = input.lookahead1(); if lookahead.peek(syn::Ident) { use syn::ext::IdentExt; let ident = input.call(syn::Ident::parse_any)?; if input.peek(syn::Token![::]) { input.parse::()?; Ok(UseTree2::Path(UsePath2 { ident, tree: Box::new(input.parse()?), })) } else if input.peek(syn::Token![=]) { if ident != "TrustLevel" { return Err(syn::parse::Error::new( ident.span(), "Unrecognized key-value pair", )); } input.parse::()?; let span = input.span(); let value = input.call(syn::Ident::parse_any)?; match value.to_string().as_str() { "Partial" => Ok(UseTree2::TrustLevel(1)), "Full" => Ok(UseTree2::TrustLevel(2)), _ => Err(syn::parse::Error::new( span, "`TrustLevel` must be `Partial` or `Full`", )), } } else { let generics = if input.peek(syn::Token![<]) { input.parse::()?; let mut generics = Vec::new(); loop { generics.push(input.parse::()?); if input.parse::().is_err() { break; } } input.parse::]>()?; generics } else { Vec::new() }; Ok(UseTree2::Name(UseName2 { ident, generics })) } } else if lookahead.peek(syn::token::Brace) { let content; let brace_token = syn::braced!(content in input); let items = content.parse_terminated(UseTree2::parse, syn::Token![,])?; Ok(UseTree2::Group(UseGroup2 { brace_token, items })) } else { Err(lookahead.error()) } } } fn convert_implements_to_interface_chains(implements: Vec) -> Vec { let mut chains = Vec::with_capacity(implements.len()); for (i, implement) in implements.into_iter().enumerate() { // Create an identifier for this interface chain. // We only use this for naming fields; it is never visible to the developer. // This helps with debugging. // // We use i + 1 so that it matches the numbering of our interface offsets. Interface 0 // is the "identity" interface. let mut ident_string = format!("interface{}", i + 1); let suffix = get_interface_ident_suffix(&implement.type_name); if !suffix.is_empty() { ident_string.push('_'); ident_string.push_str(&suffix); } let field_ident = syn::Ident::new(&ident_string, implement.span); let mut vtable_const_string = ident_string.clone(); vtable_const_string.make_ascii_uppercase(); vtable_const_string.insert_str(0, "VTABLE_"); let vtable_const_ident = syn::Ident::new(&vtable_const_string, implement.span); chains.push(InterfaceChain { implement, field_ident, vtable_const_ident, }); } chains } fn get_interface_ident_suffix(type_name: &str) -> String { let mut suffix = String::new(); for c in type_name.chars() { let c = c.to_ascii_lowercase(); if suffix.len() >= 20 { break; } if c.is_ascii_alphanumeric() { suffix.push(c); } } suffix }