Learning Objectives

  • Learn how a function call works in memory
  • Learn how to generate a function call with parameters

How do the CPU and the compiler call a function?

printf

Insert function call

#include "llvm/Pass.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/IRBuilder.h"

#include "llvm/Support/raw_ostream.h"

#include "llvm/IR/Module.h"

using namespace llvm;

namespace {
    struct YJ005InjectFuncCall : public ModulePass {
        static char ID;
        YJ005InjectFuncCall() : ModulePass(ID){}

        bool runOnModule(Module &M) override{
            LLVMContext & context = M.getContext();

            // creating a pointer Int8 type
            PointerType *PrintfArgTy = PointerType::getUnqual(Type::getInt8Ty(context));

            // step1: Inject declaration of printf
            FunctionType *PrintfTy = FunctionType::get(
                IntegerType::getInt32Ty(context),
                PrintfArgTy,
                true// variadic
            );

            FunctionCallee Printf = M.getOrInsertFunction("printf", PrintfTy);

            Function *PrintfF = dyn_cast<Function>(Printf.getCallee());
            PrintfF->setDoesNotThrow();
            PrintfF->addParamAttr(0, Attribute::NoCapture);
            PrintfF->addParamAttr(0, Attribute::ReadOnly);

            // step2: Inject global variable that will hold the printf format string
            Constant *PrintfFormatStr = ConstantDataArray::getString(
                context, "Function is %s, num of arg: %d\n");

            Constant *PrintfFormatStrVar = M.getOrInsertGlobal("PrintfFormatStr", PrintfFormatStr->getType());
            dyn_cast<GlobalVariable>(PrintfFormatStrVar)->setInitializer(PrintfFormatStr);

            // GlobalVariable *PrintfFormatStrVar = new GlobalVariable(M, PrintfFormatStr->getType(), true, // isConstant
            //     GlobalValue::PrivateLinkage, PrintfFormatStr, "PrintfFormatStr");

            bool modified = false;
            // step3: for each function in the module, inject a call to printf
            for(auto& F : M){
                if (!F.isDeclaration()){
                    // errs() << &*F.getEntryBlock().getFirstInsertionPt();

                    IRBuilder<> Builder(&*F.getEntryBlock().getFirstInsertionPt());

                    Value *FormatStrPtr = Builder.CreatePointerCast(PrintfFormatStrVar, PrintfArgTy, "formatStr");

                    auto funcName = Builder.CreateGlobalStringPtr(F.getName());
                    
                    // errs() << F.getName() << ' ' << Builder.getInt32(F.arg_size()) << " " << F.arg_size() << '\n';

                    Builder.CreateCall(Printf, {FormatStrPtr, funcName, Builder.getInt32(F.arg_size())});
                }


                for(auto& BB: F){
                    for(auto& I: BB){
                        // errs() << I << '\n';
                    }
                }

                modified = true;
            }

            return modified;
        }
    };
}

// Address of ID is what the legacy pass manager uses to identify the pass;
// the stored value itself is irrelevant.
char YJ005InjectFuncCall::ID = 0;

// Register with `opt` under the flag -InjectFunc. The pass rewrites the IR
// (it inserts calls), so it is a transformation, not an analysis.
static RegisterPass<YJ005InjectFuncCall> X("InjectFunc", "Inject function call",
                                           false /* CFGOnly */,
                                           false /* is_analysis */);

Pass | Type
Writing HelloLLVM Pass | analysis
Iterating over Module, Function, Basic block | analysis
Count the number of insts, func calls | analysis
Insert func call | transformation
Change Insts (obfuscation) | transformation
Control flow graph | transformation

Did you enjoy the LLVM tutorial? Now it’s time to study program analysis — continue here!

Reference

[1] Andrzej Warzyński. llvm-tutor. github
[2] Adrian Sampson. LLVM for Grad Students. blog
[3] Keshav Pingali. CS 380C: Advanced Topics in Compilers. blog
[4] Arthur Eubanks. The New Pass Manager. blog