Visual Studio warning if base method is not called

Posted 2019-07-27 21:32

I am looking for a way to make Visual Studio raise a warning if I override a specific method of a base class but forget to call the base method from the override. For example:

class Foo
{
   [SomeAttributeToMarkTheMethodToFireTheWarning]
   public virtual void A() { ... }
}

class Bar : Foo
{
   public override void A()
   {
      // base.A(); // warning if base.A() is not called
      // ...
   }
}

So far I haven't found a way, and it is probably not possible to make the compiler raise such a warning directly. Any ideas on how to do it, even with a third-party tool or some API from the new Roslyn .NET Compiler Platform?

UPDATE: For example, in Android Studio (IntelliJ), if you override onCreate() in an activity but forget to call super.onCreate(), you get a warning. That's the behavior I need in VS.

2 answers
唯我独甜
Post #2 · 2019-07-27 21:57

If you want to ensure that some code always runs, you should change your design:

abstract class Foo
{
   protected abstract void PostA();  

   public void A() { 
      // code that must always run goes here
      PostA();
   }
}


class Bar : Foo
{
   protected override void PostA()
   {

   }
}

// callers still call A() exactly as before:
new Bar().A();

This way, the body of A() always runs before your override (PostA()) is called.
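
A minimal, compilable sketch of this design. The "..." placeholder is replaced by a hypothetical call log (Log) so the execution order can be observed, and Demo is a made-up helper used only to run it:

```csharp
using System.Collections.Generic;

abstract class Foo
{
    // Hypothetical log, only here to make the call order observable.
    public readonly List<string> Log = new List<string>();

    // Non-virtual: subclasses cannot replace or skip this logic.
    public void A()
    {
        Log.Add("Foo.A");   // code that must always run
        PostA();            // extension point for subclasses
    }

    // Every concrete subclass is forced by the compiler to implement this.
    protected abstract void PostA();
}

class Bar : Foo
{
    protected override void PostA()
    {
        Log.Add("Bar.PostA");
    }
}

static class Demo
{
    public static List<string> Run()
    {
        var bar = new Bar();
        bar.A();            // callers still call A()
        return bar.Log;     // Foo.A is logged before Bar.PostA
    }
}
```

Because A() is not virtual, a subclass cannot override it; the only thing it can customize is PostA().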

To support further levels of inheritance and still ensure that A() is called, you would have to make Bar abstract as well:

abstract class Bar : Foo
{
   //no need to override now
}

class Baz : Bar
{
   protected override void PostA()
   {

   }
}
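
If Bar also needs logic of its own in PostA(), leaving it abstract is not enough on its own, because Baz could override PostA() and skip that logic. One way to extend the same pattern, sketched here under that assumption, is to seal Bar's override and give Baz a new abstract hook (PostBarA and the Log list are hypothetical names):

```csharp
using System.Collections.Generic;

abstract class Foo
{
    public readonly List<string> Log = new List<string>();

    public void A()
    {
        Log.Add("Foo.A");
        PostA();
    }

    protected abstract void PostA();
}

abstract class Bar : Foo
{
    // 'sealed' prevents Baz from overriding PostA and skipping Bar's logic.
    protected sealed override void PostA()
    {
        Log.Add("Bar.PostA");
        PostBarA();
    }

    // Hypothetical next-level hook that Baz is forced to implement.
    protected abstract void PostBarA();
}

class Baz : Bar
{
    protected override void PostBarA()
    {
        Log.Add("Baz.PostBarA");
    }
}
```

Calling new Baz().A() then runs the Foo, Bar and Baz code in that order, and no subclass can skip a parent's step.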

There is no way to do exactly what you want in C#. This isn't a Visual Studio issue. This is how C# works.

A virtual method may or may not be overridden, and an override may or may not call the base implementation. You have two options: virtual or abstract. You're using virtual, and I've given you an abstract solution. It's up to you to choose which one to use.

The nearest thing to what you want that I can think of is a #warning directive. See this answer. However, it only produces the warning in the Output window, not in IntelliSense. Basically, C# does not support custom compiler warnings.
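
For illustration, a minimal sketch of such a directive placed in the base method. It is reported as compiler warning CS1030 every time the file containing it is compiled, regardless of whether any override calls base.A(), so it can only serve as a reminder:

```csharp
class Foo
{
    public virtual void A()
    {
        // Reported as CS1030 on every build of this file,
        // whether or not an override calls base.A().
        #warning Overrides of Foo.A() must call base.A()
    }
}
```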

Summer. ? 凉城
Post #3 · 2019-07-27 22:14

I finally had some time to experiment with Roslyn, and it looks like I found a solution using an analyzer. Here it is.

The attribute that marks a base method whose overrides must call it:

using System;

[AttributeUsage(AttributeTargets.Method, Inherited = false, AllowMultiple = false)]
public sealed class RequireBaseMethodCallAttribute : Attribute
{
    public RequireBaseMethodCallAttribute() { }
}
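
To show the intended usage, here is the question's example adapted to the attribute (the int x parameter and the Baz class are hypothetical additions). The attribute sits only on the base method, which is why the analyzer below looks it up on symbol.OverriddenMethod rather than on the override:

```csharp
using System;

[AttributeUsage(AttributeTargets.Method, Inherited = false, AllowMultiple = false)]
public sealed class RequireBaseMethodCallAttribute : Attribute { }

class Foo
{
    // Marks A() so that every override is expected to call base.A(...).
    [RequireBaseMethodCall]
    public virtual void A(int x) { }
}

class Bar : Foo
{
    // Expected to be reported by the analyzer: base.A(x) is never called.
    public override void A(int x) { }
}

class Baz : Foo
{
    // Expected to pass: the override forwards to the base implementation.
    public override void A(int x)
    {
        base.A(x);
    }
}
```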

The analyzer:

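using System.Collections.Immutable;
using System.Linq;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Diagnostics;

[DiagnosticAnalyzer(LanguageNames.CSharp)]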
public class RequiredBaseMethodCallAnalyzer : DiagnosticAnalyzer
{
    public const string DiagnosticId = "RequireBaseMethodCall";

    // You can change these strings in the Resources.resx file. If you do not want your analyzer to be localizable, you can use regular strings for Title and MessageFormat.
    // See https://github.com/dotnet/roslyn/blob/master/docs/analyzers/Localizing%20Analyzers.md for more on localization
    private static readonly LocalizableString Title = new LocalizableResourceString(nameof(Resources.AnalyzerTitle), Resources.ResourceManager, typeof(Resources));
    private static readonly LocalizableString MessageFormat = new LocalizableResourceString(nameof(Resources.AnalyzerMessageFormat), Resources.ResourceManager, typeof(Resources));
    private static readonly LocalizableString Description = new LocalizableResourceString(nameof(Resources.AnalyzerDescription), Resources.ResourceManager, typeof(Resources));
    private const string Category = "Usage";

    private static readonly DiagnosticDescriptor Rule = new DiagnosticDescriptor(DiagnosticId, Title, MessageFormat, Category, DiagnosticSeverity.Warning, isEnabledByDefault: true, description: Description);

    public override ImmutableArray<DiagnosticDescriptor> SupportedDiagnostics { get { return ImmutableArray.Create(Rule); } }

    public override void Initialize(AnalysisContext context)
    {
        // Skip generated code and allow the analyzer to run concurrently.
        context.ConfigureGeneratedCodeAnalysis(GeneratedCodeAnalysisFlags.None);
        context.EnableConcurrentExecution();

        // No per-compilation state is needed, so register the syntax action directly.
        context.RegisterSyntaxNodeAction(AnalyzeMethodDeclaration, SyntaxKind.MethodDeclaration);
    }

    private static void AnalyzeMethodDeclaration(SyntaxNodeAnalysisContext context)
    {
        var mds = context.Node as MethodDeclarationSyntax;
        if (mds == null)
        {
            return;
        }

        IMethodSymbol symbol = context.SemanticModel.GetDeclaredSymbol(mds) as IMethodSymbol;
        if (symbol == null)
        {
            return;
        }

        if (!symbol.IsOverride)
        {
            return;
        }

        if (symbol.OverriddenMethod == null)
        {
            return;
        }

        var overridenMethod = symbol.OverriddenMethod;
        var attrs = overridenMethod.GetAttributes();
        if (!attrs.Any(ad => ad.AttributeClass?.Name == nameof(RequireBaseMethodCallAttribute)))
        {
            return;
        }

        var overridenMethodName = overridenMethod.Name.ToString();
        string methodName = overridenMethodName;

        var invocations = mds.DescendantNodes().OfType<MemberAccessExpressionSyntax>().ToList();
        foreach (var inv in invocations)
        {
            var expr = inv.Expression;
            if (expr.IsKind(SyntaxKind.BaseExpression))
            {
                var memberAccessExpr = expr.Parent as MemberAccessExpressionSyntax;
                if (memberAccessExpr == null)
                {
                    continue;
                }

                // compare exprSymbol and overridenMethod
                var exprMethodName = memberAccessExpr.Name.ToString();

                if (exprMethodName != overridenMethodName)
                {
                    continue;
                }

                var invokationExpr = memberAccessExpr.Parent as InvocationExpressionSyntax;
                if (invokationExpr == null)
                {
                    continue;
                }
                var exprMethodArgs = invokationExpr.ArgumentList.Arguments.ToList();
                var ovrMethodParams = overridenMethod.Parameters.ToList();

                // More arguments than parameters: this cannot be the overridden overload.
                // Fewer is allowed, because trailing optional parameters may be omitted.
                if (exprMethodArgs.Count > ovrMethodParams.Count)
                {
                    continue;
                }

                var paramMismatch = false;
                for (int i = 0; i < exprMethodArgs.Count; i++)
                {
                    var arg = exprMethodArgs[i];
                    var argType = context.SemanticModel.GetTypeInfo(arg.Expression);

                    var param = arg.NameColon != null ?
                                ovrMethodParams.FirstOrDefault(p => p.Name.ToString() == arg.NameColon.Name.ToString()) :
                                ovrMethodParams[i];

                    // ConvertedType accounts for implicit conversions (e.g. int to long),
                    // and symbols must be compared with SymbolEqualityComparer, not != .
                    if (param == null ||
                        !SymbolEqualityComparer.Default.Equals(argType.ConvertedType, param.Type))
                    {
                        paramMismatch = true;
                        break;
                    }

                    exprMethodArgs.Remove(arg);
                    ovrMethodParams.Remove(param);
                    i--;
                }

                // If there are any parameters left without a default value,
                // then it is not the base method overload we are looking for
                if (ovrMethodParams.Any(p => !p.HasExplicitDefaultValue))
                {
                    continue;
                }

                if (!paramMismatch)
                {
                    // If the actual arguments match with the method params
                    // then the base method invokation was found
                    // and there is no need to continue the search
                    return;
                }
            }
        }

        var diag = Diagnostic.Create(Rule, mds.GetLocation(), methodName);
        context.ReportDiagnostic(diag);
    }
}
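
A possible simplification of the matching logic above, offered as a sketch rather than a tested replacement: instead of comparing the method name, argument count and argument types by hand, the semantic model can resolve each base.X(...) call, and the resolved symbol can be compared with the overridden method. Overload resolution, optional and named arguments, params arrays and implicit conversions are then all handled by the compiler. The sketch also walks the whole override chain, so a class that overrides an override of the marked method is still checked (the code above only inspects the direct OverriddenMethod). It assumes the Microsoft.CodeAnalysis.CSharp package in a version that includes SymbolEqualityComparer; BaseCallCheck and its methods are hypothetical names:

```csharp
using System.Linq;
using System.Threading;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;

// Hypothetical helper; could replace the hand-written matching in AnalyzeMethodDeclaration.
public static class BaseCallCheck
{
    // True if the attribute is on any method further up the override chain.
    public static bool RequiresBaseCall(IMethodSymbol method, string attributeName)
    {
        for (var m = method.OverriddenMethod; m != null; m = m.OverriddenMethod)
        {
            if (m.GetAttributes().Any(a => a.AttributeClass?.Name == attributeName))
            {
                return true;
            }
        }
        return false;
    }

    // True if the body contains a base.X(...) call that the compiler binds
    // to 'overriddenMethod' or to a method it overrides.
    public static bool CallsBase(MethodDeclarationSyntax method, IMethodSymbol overriddenMethod,
                                 SemanticModel model, CancellationToken cancellationToken = default)
    {
        foreach (var invocation in method.DescendantNodes().OfType<InvocationExpressionSyntax>())
        {
            // Only look at calls of the form base.Name(...).
            if (!(invocation.Expression is MemberAccessExpressionSyntax access) ||
                !access.Expression.IsKind(SyntaxKind.BaseExpression))
            {
                continue;
            }

            // Let overload resolution decide which method is being called.
            if (!(model.GetSymbolInfo(invocation, cancellationToken).Symbol is IMethodSymbol called))
            {
                continue;
            }

            // Accept any method in the override chain, since the binder may report
            // either the nearest override or the original virtual declaration.
            for (var m = overriddenMethod; m != null; m = m.OverriddenMethod)
            {
                if (SymbolEqualityComparer.Default.Equals(called.OriginalDefinition, m.OriginalDefinition))
                {
                    return true;
                }
            }
        }
        return false;
    }
}
```

In AnalyzeMethodDeclaration, the attribute lookup and the whole foreach loop could then be replaced by a condition like BaseCallCheck.RequiresBaseCall(symbol, nameof(RequireBaseMethodCallAttribute)) && !BaseCallCheck.CallsBase(mds, symbol.OverriddenMethod, context.SemanticModel, context.CancellationToken), followed by the existing ReportDiagnostic call.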

The CodeFix provider:

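using System.Collections.Generic;
using System.Collections.Immutable;
using System.Composition;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CodeActions;
using Microsoft.CodeAnalysis.CodeFixes;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Simplification;
using Microsoft.CodeAnalysis.Text;

[ExportCodeFixProvider(LanguageNames.CSharp, Name = nameof(BaseMethodCallCodeFixProvider)), Shared]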
public class BaseMethodCallCodeFixProvider : CodeFixProvider
{
    private const string title = "Add base method invocation";

    public sealed override ImmutableArray<string> FixableDiagnosticIds
    {
        get { return ImmutableArray.Create(RequiredBaseMethodCallAnalyzer.DiagnosticId); }
    }

    public sealed override FixAllProvider GetFixAllProvider()
    {
        // See https://github.com/dotnet/roslyn/blob/master/docs/analyzers/FixAllProvider.md for more information on Fix All Providers
        return WellKnownFixAllProviders.BatchFixer;
    }

    public sealed override async Task RegisterCodeFixesAsync(CodeFixContext context)
    {
        var diagnostic = context.Diagnostics.First();
        var diagnosticSpan = diagnostic.Location.SourceSpan;

        // Register a code action that will invoke the fix.
        context.RegisterCodeFix(
            CodeAction.Create(
                title: title,
                createChangedDocument: c => AddBaseMethodCallAsync(context.Document, diagnosticSpan, c),
                equivalenceKey: title),
            diagnostic);
    }

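    private async Task<Document> AddBaseMethodCallAsync(Document document, TextSpan diagnosticSpan, CancellationToken cancellationToken)
    {
        var root = await document.GetSyntaxRootAsync(cancellationToken).ConfigureAwait(false);

        // Walk up from the reported span to the enclosing method declaration.
        var node = root.FindToken(diagnosticSpan.Start).Parent
                       .AncestorsAndSelf().OfType<MethodDeclarationSyntax>().FirstOrDefault();

        // Expression-bodied overrides (=> ...) have no block body; leave them unchanged.
        if (node?.Body == null)
        {
            return document;
        }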

        var args = new List<ArgumentSyntax>();
        foreach (var param in node.ParameterList.Parameters)
        {
            args.Add(SyntaxFactory.Argument(SyntaxFactory.ParseExpression(param.Identifier.ValueText)));
        }

        var argsList = SyntaxFactory.SeparatedList(args);

        var exprStatement = SyntaxFactory.ExpressionStatement(
            SyntaxFactory.InvocationExpression(
                SyntaxFactory.MemberAccessExpression(
                    SyntaxKind.SimpleMemberAccessExpression,
                    SyntaxFactory.BaseExpression(),
                    SyntaxFactory.Token(SyntaxKind.DotToken),
                    SyntaxFactory.IdentifierName(node.Identifier.ToString())
                ),
                SyntaxFactory.ArgumentList(argsList)
            ),
            SyntaxFactory.Token(SyntaxKind.SemicolonToken)
        );

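        // Keep the existing braces and trivia; only prepend the new statement and
        // let the formatter fix its whitespace (instead of simplifying the whole document).
        var newBody = node.Body.WithStatements(node.Body.Statements.Insert(
            0, exprStatement.WithAdditionalAnnotations(Microsoft.CodeAnalysis.Formatting.Formatter.Annotation)));
        var newRoot = root.ReplaceNode(node.Body, newBody);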

        return document.WithSyntaxRoot(newRoot);
    }
}
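
As an alternative to the nested SyntaxFactory calls in AddBaseMethodCallAsync, the base call can also be built by parsing a string. A hedged sketch (BaseCallBuilder is a hypothetical name); unlike the code above, it also forwards ref, out and in modifiers, without which the generated base call would not compile for such parameters:

```csharp
using System.Linq;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;

// Hypothetical helper: builds "base.Name(a, ref b, out c);" from the override's parameters.
public static class BaseCallBuilder
{
    public static StatementSyntax Build(MethodDeclarationSyntax method)
    {
        var args = method.ParameterList.Parameters.Select(p =>
        {
            // Identifier.Text keeps a leading '@' for names like @class.
            var name = p.Identifier.Text;
            if (p.Modifiers.Any(SyntaxKind.RefKeyword)) return "ref " + name;
            if (p.Modifiers.Any(SyntaxKind.OutKeyword)) return "out " + name;
            if (p.Modifiers.Any(SyntaxKind.InKeyword)) return "in " + name;
            return name;
        });

        return SyntaxFactory.ParseStatement(
            $"base.{method.Identifier.Text}({string.Join(", ", args)});");
    }
}
```

The returned statement could be inserted at the same place as exprStatement; it would still benefit from the formatter, since ParseStatement only keeps the whitespace present in the string.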

And a demo of how it works: http://screencast.com/t/4Jgm989TI

Since I am totally new to the .NET Compiler Platform, I would love to have any feedback and suggestions on how to improve my solution. Thank you in advance!
