How to invoke a method of a private COM interface

Posted 2019-01-27 04:50

Question:

How can I invoke a method of a private COM interface, defined in a base class, from a derived class?

For example, here is the COM interface, IComInterface (IDL):

[
    uuid(9AD16CCE-7588-486C-BC56-F3161FF92EF2),
    oleautomation
]
interface IComInterface: IUnknown
{
    HRESULT ComMethod([in] IUnknown* arg);
}

Here's the C# class BaseClass from the OldLibrary assembly, which implements IComInterface like this (note that the interface is declared as private):

// Assembly "OldLibrary"
public static class OldLibrary
{
    [ComImport(), Guid("9AD16CCE-7588-486C-BC56-F3161FF92EF2")]
    [InterfaceType(ComInterfaceType.InterfaceIsIUnknown)]
    private interface IComInterface
    {
        void ComMethod([In, MarshalAs(UnmanagedType.Interface)] object arg);
    }

    [ComVisible(true)]
    [ClassInterface(ClassInterfaceType.None)]
    public class BaseClass : IComInterface
    {
        void IComInterface.ComMethod(object arg)
        {
            Console.WriteLine("BaseClass.IComInterface.ComMethod");
        }
    }
}

Finally, here's an improved version, ImprovedClass, which derives from BaseClass but declares and implements its own version of IComInterface, because the base's OldLibrary.IComInterface is inaccessible:

// Assembly "NewLibrary"
public static class NewLibrary
{
    [ComImport(), Guid("9AD16CCE-7588-486C-BC56-F3161FF92EF2")]
    [InterfaceType(ComInterfaceType.InterfaceIsIUnknown)]
    private interface IComInterface
    {
        void ComMethod([In, MarshalAs(UnmanagedType.Interface)] object arg);
    }

    [ComVisible(true)]
    [ClassInterface(ClassInterfaceType.None)]
    public class ImprovedClass : 
        OldLibrary.BaseClass, 
        IComInterface, 
        ICustomQueryInterface
    {
        // IComInterface
        void IComInterface.ComMethod(object arg)
        {
            Console.WriteLine("ImprovedClass.IComInterface.ComMethod");
            // How do I call base.ComMethod here, 
            // otherwise than via reflection?
        }

        // ICustomQueryInterface
        public CustomQueryInterfaceResult GetInterface(ref Guid iid, out IntPtr ppv)
        {
            if (iid == typeof(IComInterface).GUID)
            {
                ppv = Marshal.GetComInterfaceForObject(this, typeof(IComInterface), CustomQueryInterfaceMode.Ignore);
                return CustomQueryInterfaceResult.Handled;
            }
            ppv = IntPtr.Zero;
            return CustomQueryInterfaceResult.NotHandled;
        }   

    }
}

How do I call BaseClass.ComMethod from ImprovedClass.ComMethod without reflection?
I could use reflection, but in the real use case IComInterface is a complex OLE interface with a number of members of complex signatures.

Since BaseClass.IComInterface and ImprovedClass.IComInterface are COM interfaces with the same GUID and identical method signatures, and .NET 4.0+ supports COM Type Equivalence, I thought there had to be a way to do what I'm after without reflection.
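To make the premise concrete, here is a minimal sketch of the managed side of the problem. It assumes two hypothetical containers, LibA and LibB, each declaring its own IShared interface with the GUID from the question. ComImport is deliberately left out so the sketch does not depend on COM interop, which means it says nothing about what QueryInterface would return over COM.

```csharp
using System;
using System.Runtime.InteropServices;

public static class LibA
{
    // Hypothetical stand-in for OldLibrary.IComInterface.
    [Guid("9AD16CCE-7588-486C-BC56-F3161FF92EF2")]
    public interface IShared { void Ping(); }

    public class Impl : IShared
    {
        public void Ping() { }
    }
}

public static class LibB
{
    // Hypothetical stand-in for NewLibrary.IComInterface: same GUID, same shape.
    [Guid("9AD16CCE-7588-486C-BC56-F3161FF92EF2")]
    public interface IShared { void Ping(); }
}

public static class Program
{
    public static void Main()
    {
        object impl = new LibA.Impl();

        // Compare the GUIDs reported by the two interface types.
        Console.WriteLine(typeof(LibA.IShared).GUID == typeof(LibB.IShared).GUID);

        // Compare the managed types themselves.
        Console.WriteLine(typeof(LibA.IShared) == typeof(LibB.IShared));

        // Check whether an implementer of LibA.IShared is usable as LibB.IShared.
        Console.WriteLine(impl is LibB.IShared);
    }
}
```

The two declarations are expected to report the same GUID while remaining distinct managed types, so a plain managed cast from one to the other is not available. That is why the question looks for a route through COM or reflection.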

Another requirement is that ImprovedClass has to be derived from BaseClass, because the C# client code expects an instance of BaseClass, which it passes to the COM client code. Thus, containment of BaseClass inside ImprovedClass is not an option.

[EDITED] A real-life scenario which involves deriving from WebBrowser and WebBrowserSite is described here.

Answer 1:

I'm used to doing this in C++, so I'm mentally translating from C++ to C# here. (I.e., you may have to do some tweaking.)

COM identity rules require the set of interfaces on an object to be static. So, if you can get some interface that's definitely implemented by BaseClass, you can QI off that interface to get BaseClass's implementation of IComInterface.

So, something like this:

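// Untested sketch: find BaseClass's own (inaccessible) interface type by GUID.
Type typeBaseIComInterface = typeof(OldLibrary.BaseClass).GetInterfaces().First((t) => t.GUID == typeof(IComInterface).GUID);
// Ask for that interface on this object, ignoring ICustomQueryInterface.
IntPtr unkBaseIComInterface = Marshal.GetComInterfaceForObject(this, typeBaseIComInterface, CustomQueryInterfaceMode.Ignore);
dynamic baseptr = Marshal.GetTypedObjectForIUnknown(unkBaseIComInterface, typeof(OldLibrary.BaseClass));
baseptr.ComMethod(/* args go here */);
// GetComInterfaceForObject returned an AddRef'ed pointer; release it when done.
Marshal.Release(unkBaseIComInterface);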


Answer 2:

Here is my solution. OK, it uses reflection, but I don't see where the problem is, since it's much simpler and the final usage is really just one line of code, like this:

// IComInterface
void IComInterface.ComMethod(object arg)
{
    InvokeBaseMethod(this, "ComMethod", typeof(OldLibrary.BaseClass), typeof(IComInterface), arg);
}

and the utility method (reusable for any class) is this:

public static object InvokeBaseMethod(object obj, string methodName, Type baseType, Type equivalentBaseInterface, params object[] arguments)
{
    Type baseInterface = baseType.GetInterfaces().First((t) => t.GUID == equivalentBaseInterface.GUID);
    ComMemberType type = ComMemberType.Method;
    int methodSlotNumber = Marshal.GetComSlotForMethodInfo(equivalentBaseInterface.GetMethod(methodName));
    MethodInfo baseMethod = (MethodInfo)Marshal.GetMethodInfoForComSlot(baseInterface, methodSlotNumber, ref type);
    return baseMethod.Invoke(obj, arguments);
}
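For comparison, the same idea can be sketched without the COM slot APIs, using only an interface map. This is a hedged illustration rather than the answer's method: IHidden, OldLib, NewLib and BaseInvoker are hypothetical names, the interfaces omit ComImport so the program does not need COM interop, and methods are matched by name instead of by slot number (an assumption that breaks down for overloads).

```csharp
using System;
using System.Linq;
using System.Reflection;
using System.Runtime.InteropServices;

public static class OldLib
{
    // Hypothetical stand-in for the inaccessible base interface.
    [Guid("9AD16CCE-7588-486C-BC56-F3161FF92EF2")]
    private interface IHidden
    {
        string Describe(object arg);
    }

    public class BaseClass : IHidden
    {
        string IHidden.Describe(object arg) => "BaseClass.IHidden.Describe";
    }
}

public static class BaseInvoker
{
    // Finds the base type's interface with the same GUID as equivalentInterface,
    // then invokes the base type's implementation of the named method.
    public static object Invoke(object target, Type baseType, Type equivalentInterface,
                                string methodName, params object[] args)
    {
        Type baseInterface = baseType.GetInterfaces()
            .First(t => t.GUID == equivalentInterface.GUID);
        InterfaceMapping map = baseType.GetInterfaceMap(baseInterface);
        int index = Array.FindIndex(map.InterfaceMethods, m => m.Name == methodName);
        return map.TargetMethods[index].Invoke(target, args);
    }
}

public static class NewLib
{
    // Equivalent re-declaration: same GUID, same signature.
    [Guid("9AD16CCE-7588-486C-BC56-F3161FF92EF2")]
    private interface IHidden
    {
        string Describe(object arg);
    }

    public class ImprovedClass : OldLib.BaseClass, IHidden
    {
        string IHidden.Describe(object arg)
        {
            // Call the base implementation first, then add our own part.
            object fromBase = BaseInvoker.Invoke(
                this, typeof(OldLib.BaseClass), typeof(IHidden), "Describe", arg);
            return fromBase + " -> ImprovedClass.IHidden.Describe";
        }

        // Test helper, since IHidden is private to NewLib.
        public string Run() => ((IHidden)this).Describe(null);
    }
}

public static class Program
{
    public static void Main()
    {
        Console.WriteLine(new NewLib.ImprovedClass().Run());
    }
}
```

Matching by slot, as InvokeBaseMethod does, tolerates renamed or overloaded members in the re-declared interface. Matching by name is simpler to read but relies on both declarations using identical, unique method names.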


Answer 3:

I figured it out by using a helper contained object (BaseClassComProxy) and an aggregated COM proxy object, created with Marshal.CreateAggregatedObject. This approach gives me an unmanaged object with a separate identity, which I can cast (with Marshal.GetTypedObjectForIUnknown) to my own equivalent version of the BaseClass.IComInterface interface, which is not otherwise accessible. It works for any other private COM interface implemented by BaseClass.

@EricBrown's points about COM identity rules have helped a lot with this research. Thanks Eric!

Here's a standalone console test app. The code solving the original problem with WebBrowserSite is posted here.

using System;
using System.Diagnostics;
using System.Linq;
using System.Runtime.InteropServices;

namespace ManagedServer
{
    /*
    // IComInterface IDL definition
    [
        uuid(9AD16CCE-7588-486C-BC56-F3161FF92EF2),
        oleautomation
    ]
    interface IComInterface: IUnknown
    {
        HRESULT ComMethod(IUnknown* arg);
    }
    */

    // OldLibrary
    public static class OldLibrary
    {
        // private COM interface IComInterface
        [ComImport(), Guid("9AD16CCE-7588-486C-BC56-F3161FF92EF2")]
        [InterfaceType(ComInterfaceType.InterfaceIsIUnknown)]
        private interface IComInterface
        {
            void ComMethod([In, MarshalAs(UnmanagedType.Interface)] object arg);
        }

        [ComVisible(true)]
        [ClassInterface(ClassInterfaceType.None)]
        public class BaseClass : IComInterface
        {
            void IComInterface.ComMethod(object arg)
            {
                Console.WriteLine("BaseClass.IComInterface.ComMethod");
            }
        }
    }

    // NewLibrary 
    public static class NewLibrary
    {
        // OldLibrary.IComInterface is inaccessible here,
        // define a new equivalent version
        [ComImport(), Guid("9AD16CCE-7588-486C-BC56-F3161FF92EF2")]
        [InterfaceType(ComInterfaceType.InterfaceIsIUnknown)]
        private interface IComInterface
        {
            void ComMethod([In, MarshalAs(UnmanagedType.Interface)] object arg);
        }

        [ComVisible(true)]
        [ClassInterface(ClassInterfaceType.None)]
        public class ImprovedClass :
            OldLibrary.BaseClass,
            NewLibrary.IComInterface,
            ICustomQueryInterface,
            IDisposable
        {
            NewLibrary.IComInterface _baseIComInterface;
            BaseClassComProxy _baseClassComProxy;

            // IComInterface
            // we want to call BaseClass.IComInterface.ComMethod which is only accessible via COM
            void IComInterface.ComMethod(object arg)
            {
                _baseIComInterface.ComMethod(arg);
                Console.WriteLine("ImprovedClass.IComInterface.ComMethod");
            }

            // ICustomQueryInterface
            public CustomQueryInterfaceResult GetInterface(ref Guid iid, out IntPtr ppv)
            {
                if (iid == typeof(NewLibrary.IComInterface).GUID)
                {
                    // CustomQueryInterfaceMode.Ignore is to avoid infinite loop during QI.
                    ppv = Marshal.GetComInterfaceForObject(this, typeof(NewLibrary.IComInterface), CustomQueryInterfaceMode.Ignore);
                    return CustomQueryInterfaceResult.Handled;
                }
                ppv = IntPtr.Zero;
                return CustomQueryInterfaceResult.NotHandled;
            }

            // constructor
            public ImprovedClass()
            {
                // aggregate the CCW object with the helper Inner object
                _baseClassComProxy = new BaseClassComProxy(this);
                _baseIComInterface = _baseClassComProxy.GetComInterface<IComInterface>();   
            }

            ~ImprovedClass()
            {
                Dispose();
                Console.WriteLine("ImprovedClass finalized.");
            }

            // IDisposable
            public void Dispose()
            {
                // we may hold circular COM references to ourselves,
                // e.g., via _baseIComInterface;
                // make sure to release all references

                if (_baseIComInterface != null)
                {
                    Marshal.ReleaseComObject(_baseIComInterface);
                    _baseIComInterface = null;
                }

                if (_baseClassComProxy != null)
                {
                    _baseClassComProxy.Dispose();
                    _baseClassComProxy = null;
                }
            }

            // for testing
            public void InvokeComMethod()
            {
                ((NewLibrary.IComInterface)this).ComMethod(null);
            }
        }

        #region BaseClassComProxy
        // Inner as aggregated object
        class BaseClassComProxy :
            ICustomQueryInterface,
            IDisposable
        {
            WeakReference _outer; // avoid circular refs between outer and inner object
            Type[] _interfaces; // the base's private COM interfaces are here
            IntPtr _unkAggregated; // aggregated proxy

            public BaseClassComProxy(object outer)
            {
                _outer = new WeakReference(outer);
                _interfaces = outer.GetType().BaseType.GetInterfaces();
                var unkOuter = Marshal.GetIUnknownForObject(outer);
                try
                {
                    // CreateAggregatedObject does AddRef on this,
                    // so we provide IDisposable for proper shutdown
                    _unkAggregated = Marshal.CreateAggregatedObject(unkOuter, this); 
                }
                finally
                {
                    Marshal.Release(unkOuter);
                }
            }

            public T GetComInterface<T>() where T : class
            {
                // cast an outer's base interface to an equivalent outer's interface
                return (T)Marshal.GetTypedObjectForIUnknown(_unkAggregated, typeof(T));
            }

            public void GetComInterface<T>(out T baseInterface) where T : class
            {
                baseInterface = GetComInterface<T>();
            }

            ~BaseClassComProxy()
            {
                Dispose();
                Console.WriteLine("BaseClassComProxy object finalized.");
            }

            // IDisposable
            public void Dispose()
            {
                if (_outer != null)
                {
                    _outer = null;
                    _interfaces = null;
                    if (_unkAggregated != IntPtr.Zero)
                    {
                        Marshal.Release(_unkAggregated);
                        _unkAggregated = IntPtr.Zero;
                    }
                }
            }

            // ICustomQueryInterface
            public CustomQueryInterfaceResult GetInterface(ref Guid iid, out IntPtr ppv)
            {
                // access to the outer's base private COM interfaces
                if (_outer != null)
                {
                    var ifaceGuid = iid;
                    var iface = _interfaces.FirstOrDefault((i) => i.GUID == ifaceGuid);
                    if (iface != null && iface.IsImport)
                    {
                        // must be a COM interface with ComImport attribute
                        var unk = Marshal.GetComInterfaceForObject(_outer.Target, iface, CustomQueryInterfaceMode.Ignore);
                        if (unk != IntPtr.Zero)
                        {
                            ppv = unk;
                            return CustomQueryInterfaceResult.Handled;
                        }
                    }
                }
                ppv = IntPtr.Zero;
                return CustomQueryInterfaceResult.Failed;
            }
        }
        #endregion

    }

    class Program
    {
        static void Main(string[] args)
        {
            // test
            var improved = new NewLibrary.ImprovedClass();
            improved.InvokeComMethod(); 

            //// COM client
            //var unmanagedObject = (ISimpleUnmanagedObject)Activator.CreateInstance(Type.GetTypeFromProgID("Noseratio.SimpleUnmanagedObject"));
            //unmanagedObject.InvokeComMethod(improved);

            improved.Dispose();
            improved = null;

            // test ref counting
            GC.Collect(generation: GC.MaxGeneration, mode: GCCollectionMode.Forced, blocking: false);
            Console.WriteLine("Press Enter to exit.");
            Console.ReadLine();
        }

        // COM test client interfaces
        [ComImport(), Guid("2EA68065-8890-4F69-A02F-2BC3F0418561")]
        [InterfaceType(ComInterfaceType.InterfaceIsDual)]
        internal interface ISimpleUnmanagedObject
        {
            void InvokeComMethod([In, MarshalAs(UnmanagedType.Interface)] object arg);
            void InvokeComMethodDirect([In] IntPtr comInterface);
        }

    }
}

Output:

BaseClass.IComInterface.ComMethod
ImprovedClass.IComInterface.ComMethod
Press Enter to exit.
BaseClassComProxy object finalized.
ImprovedClass finalized.


Answer 4:

You need to use ICustomMarshaler. I just worked out this solution and it is much less complex than what you've got, and there's no reflection. As far as I can tell, ICustomMarshaler is the only way to explicitly control that magical ability of managed objects (such as RCW proxies) to be cast, on the fly, into managed interface pointers that they don't appear to explicitly implement.

For the complete scenario I'll demonstrate, the names in parentheses refer to the relevant parts of my example.

Scenario

You are receiving an unmanaged interface pointer (pUnk) into your managed code via a COM interop function (e.g. MFCreateMediaSession), perhaps previously using the excellent interop attribute ([MarshalAs(UnmanagedType.Interface)] out IMFMediaSession pSess, ...) in order to receive a managed interface (IMFMediaSession). You'd like to "improve" (as you say) upon the backing __COM object that you get in this situation by providing your own managed class (session) which:

  1. perhaps adds some additional interfaces (e.g. IMFAsyncCallback);
  2. doesn't require you to forward or re-implement the interface you're already getting;
  3. consolidates the lifetime of the unmanaged interface(s) with the managed ones in a single RCW;
  4. doesn't store any extraneous interface pointers...

The key is to change the marshaling directive on the function that obtains the unmanaged object so that it uses a custom marshaler. If the P/Invoke definition is in an external library you don't control, you can make your own local copy. That's what I did here, where I replaced [Out, MarshalAs(UnmanagedType.Interface)] with the new attribute:

    [DllImport("mf.dll", ExactSpelling = true), SuppressUnmanagedCodeSecurity]
    static extern HResult MFCreateMediaSession(
        [In] IMFAttributes pConfiguration,
        [Out, MarshalAs(UnmanagedType.CustomMarshaler, MarshalTypeRef = typeof(MFSessionMarshaler))] out IMFMediaSession ppMediaSession
        );

To deploy your own class that has the 'magical' interface behavior I mentioned above, you'll need two classes: an abstract base class which must be marked with [ComImport] (even if it's not, really) to provide the RCW plumbing, plus the other attributes I show (create your own GUID), and then a derived class, where you can put whatever enhanced functionality you like.

The thing to note here is that neither the base class (_session in my example) nor the derived class (session) may explicitly list the interface that you expect it to proxy from an unmanaged IUnknown. Any "proper" interface definition that duplicates a QueryInterface version will take precedence and ruin your ability to effortlessly call the unmanaged "base" methods via casting. You'll be back to COM slots and _vtbl land.

This also means that, on instances of the derived class you will only be able to access the imported interface by casting. The derived class can implement the other, "extra" interfaces in the usual way. Those can be imported COM interfaces also, by the way.

Here are the two classes I just described where your app content goes. Notice how uncluttered they are compared to if you had to forward a gigantic interface through one or more member variables (which you'd have to initialize, and clean up, etc.)

[ComImport, SuppressUnmanagedCodeSecurity, Guid("c6646f0a-3d96-4ac2-9e3f-8ae2a11145ce")]
[ClassInterface(ClassInterfaceType.None)]
public abstract class _session
{
}

public class session : _session, IMFAsyncCallback
{
    HResult IMFAsyncCallback.GetParameters(out MFASync pdwFlags, out MFAsyncCallbackQueue pdwQueue)
    {
        /// add-on interfaces can use explicit implementation...
    }

    public HResult Invoke([In, MarshalAs(UnmanagedType.Interface)] IMFAsyncResult pAsyncResult)
    {
        /// ...or public.
    }
}

Next is the ICustomMarshaler implementation. Because the argument we tagged to use this is an out argument, the managed-to-native functions of this class will never be called. The main function to implement is MarshalNativeToManaged, where I use GetTypedObjectForIUnknown specifying the derived class I defined (session). Even though that class doesn't implement IMFMediaSession, you'll be able to obtain that unmanaged interface via casting.

Calling Release in the CleanUpNativeData call is currently my best guess. (If it's wrong, I'll come back to edit this post).

class MFSessionMarshaler : ICustomMarshaler
{
    static ICustomMarshaler GetInstance(String _) => new MFSessionMarshaler();

    public Object MarshalNativeToManaged(IntPtr pUnk) => Marshal.GetTypedObjectForIUnknown(pUnk, typeof(session));

    public void CleanUpNativeData(IntPtr pNativeData) => Marshal.Release(pNativeData);

    public int GetNativeDataSize() => -1;
    IntPtr ICustomMarshaler.MarshalManagedToNative(Object _) => IntPtr.Zero;
    void ICustomMarshaler.CleanUpManagedData(Object ManagedObj) { }
}

Here we see one of the few places in .NET I am aware of where you are allowed to (temporarily) violate type safety. Notice that ppMediaSession pops out of the marshaler into your code as a full-blown, strongly-typed argument (out IMFMediaSession ppMediaSession), but it certainly wasn't acting as such (i.e. without casting) immediately beforehand in the custom marshaling code.

Now you're ready to go. Here are some examples showing how you can use it, and demonstrating that things work as expected:

IMFMediaSession pI;
MFCreateMediaSession(null, out pI);  // get magical RCW

var rcw = (session)pI;   // we happen to know what it really is

pI.ClearTopologies();    // you can call IMFMediaSession members...

((IMFAsyncCallback)pI).Invoke(null);  // and also IMFAsyncCallback.
rcw.Invoke(null);        // same thing, via the backing object