Generating an Ethereum address from scratch

In this blog post I would like to introduce readers to Elliptic Curve Cryptography and use this knowledge to construct an Ethereum address.

Part 1: Elliptic Curves

An Elliptic Curve is a curve defined by the equation

\[y^2 = x^3 + ax + b.\]

The curve has a number of nice properties, and one of them is that if we take two points on the curve and draw a line through them, the line will intersect the curve at exactly one more point. You can play around with this property using the Bokeh application below:

CODE

import numpy as np

from bokeh.layouts import column, row
from bokeh.models import CustomJS, Slider
from bokeh.plotting import ColumnDataSource, figure, show


def create_layout(a: int, b: int, start_xs=[1, 6], x_max=15, flip=False):
    """Build an interactive Bokeh layout for the curve y^2 = x^3 + a*x + b.

    The plot shows both branches of the curve, two points on it, and the
    straight line through those points. Two sliders move the points along the
    x-axis; a JavaScript callback recomputes the points and the line.

    Args:
        a: Coefficient of x in the curve equation.
        b: Constant term of the curve equation.
        start_xs: Initial x-coordinates of the two points. The default list is
            never mutated: it is immediately converted into a NumPy array.
        x_max: Right-hand end of the plotted x range.
        flip: If true, the first point is placed on the lower (negative-y)
            branch instead of the upper one.

    Returns:
        A Bokeh column containing the plot and, below it, the two sliders.
    """
    # Roots of x^3 + a*x + b; the curve only has real y where this is >= 0.
    roots = np.roots([1, 0, a, b])
    # NOTE(review): this takes the first real root np.roots returns. If the
    # cubic has three real roots it may not be the largest one, in which case
    # np.sqrt below would produce NaNs on part of the range - confirm for the
    # (a, b) values in use.
    min_ = roots[np.isreal(roots)][0].real
    # Nudge the left edge right until the cubic is non-negative, so the square
    # root at the first sample is defined despite floating-point error.
    while min_**3 + min_ * a + b < 0:
        min_ += 0.0001

    step = 0.01
    x_vals = np.arange(min_, x_max, step)
    # Upper branch of the curve; the lower branch is its negation.
    y_vals = np.sqrt(x_vals**3 + a * x_vals + b)

    start_xs = np.array(start_xs)
    # Both starting points begin on the upper branch ...
    start_ys = np.sqrt(start_xs**3 + a * start_xs + b)
    if flip:
        # ... unless flip is set, which mirrors the first one to the lower branch.
        start_ys[0] = -start_ys[0]
    source = ColumnDataSource(data=dict(x=start_xs, y=start_ys))

    # The line is stored as four vertices: left plot edge, point 1, point 2,
    # right plot edge. The edge y-values extend the chord through the points.
    line_xs = [x_vals.min()] + list(start_xs) + [x_vals.max()]
    slope = (start_ys[1] - start_ys[0]) / (start_xs[1] - start_xs[0])
    y_line_min = slope * (x_vals.min() - start_xs[0]) + start_ys[0]
    y_line_max = slope * (x_vals.max() - start_xs[0]) + start_ys[0]
    line_ys = [y_line_min] + list(start_ys) + [y_line_max]

    source_line = ColumnDataSource(data=dict(x=line_xs, y=line_ys))

    # One slider per point, each spanning the plotted x range.
    p1_slider = Slider(start=x_vals.min(), end=x_vals.max(), value=start_xs[0],
                       step=step, title="Point1")
    p2_slider = Slider(start=x_vals.min(), end=x_vals.max(), value=start_xs[1],
                       step=step, title="Point2")

    # NOTE(review): Bokeh 3.0 removed plot_width/plot_height in favour of
    # width/height - confirm against the Bokeh version this is run with.
    plot = figure(plot_width=400, plot_height=400)
    plot.line(x_vals, y_vals, line_width=2)
    plot.line(x_vals, -y_vals, line_width=2)

    plot.circle('x', 'y', source=source, size=10, color="navy", alpha=0.5)
    plot.line('x', 'y', source=source_line, line_width=1, color="black")

    # The callback below is JavaScript generated from an f-string: a, b and
    # the optional leading minus sign (for the flipped first point) are baked
    # into the script text when the layout is created.
    minus = "-" if flip else ""
    callback = CustomJS(args=dict(source_dots=source,
                                  source_line=source_line,
                                  p1=p1_slider, p2=p2_slider),
                        code=f"""
        const data = source_dots.data;
        const x1 = p1.value;
        const x2 = p2.value;
        const x = data['x'];
        const y = data['y'];
        x[0] = x1;
        x[1] = x2;
        const y1 = {minus}Math.sqrt(Math.pow(x1, 3) + {a} * x1 + {b});
        const y2 = Math.sqrt(Math.pow(x2, 3) + {a} * x2 + {b});
        y[0] = y1;
        y[1] = y2;
        source_line.data['x'][1] = x1;
        source_line.data['x'][2] = x2;
        source_line.data['y'][1] = y1;
        source_line.data['y'][2] = y2;
        const slope = (y2 - y1) / (x2 - x1);
        source_line.data['y'][0] = slope * (source_line.data['x'][0] - x1) + y1;
        source_line.data['y'][3] = slope * (source_line.data['x'][3] - x1) + y1;
        source_dots.change.emit();
        source_line.change.emit();
    """)

    # Either slider moving triggers the same recomputation.
    p1_slider.js_on_change('value', callback)
    p2_slider.js_on_change('value', callback)

    layout = column(
        plot,
        column(p1_slider, p2_slider),
    )

    return layout


# Two views of the same curve y^2 = x^3 + 2x + 29: on the left the first point
# sits on the lower branch (flip=True), on the right both points are on the
# upper branch.
l1 = create_layout(2, 29, flip=True)
l2 = create_layout(2, 29, flip=False)

# Render the two layouts side by side.
show(row(l1, l2))

This property allows us to map two points on a curve to a third point, an operation that we will refer to as an addition (with respect to the curve). We call this an addition because it shares certain similarities with arithmetic addition. However, we refine our addition so that it results in the reflected point, see the illustration below:

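Algebraically, this means that for \(P = (x_P, y_P)\) and \(Q = (x_Q, y_Q)\) the sum \(R = P + Q = (x_R, y_R)\) is given by

\begin{equation}
x_R = s^2 - x_P - x_Q, \qquad y_R = s(x_P - x_R) - y_P,
\label{eqn:point_r}
\end{equation}

where the slope, \(s\), is defined as

\[s = \frac{y_Q - y_P}{x_Q - x_P}\]

when \(P \neq Q\), and

\[s = \frac{3x_P^2 + a}{2y_P}\]

when \(P = Q\). Note that in this case the slope is just that of the tangent, and it is easy to derive by differentiating the curve equation implicitly:

\[2y \frac{dy}{dx} = 3x^2 + a \quad\Longrightarrow\quad \frac{dy}{dx} = \frac{3x^2 + a}{2y}.\]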

We can verify algebraically that \(y_R^2 = x_R^3 + ax_R + b\), meaning \(R\) is on our curve. Finally, for our addition to be valid, we need to define a point that would serve the role of \(0\). So, by definition, \(P + O = P\), and \(-P + P = O\). My vague intuition is that I can think of \(O\) as the entire x-axis.

Part 2: Mathematical Fields

If we pick a prime number \(p\), we can define multiplication, addition, subtraction and division modulo \(p\):

class Field:
    """Arithmetic on integers modulo ``p``.

    Every operation takes plain ints and returns a plain int reduced mod ``p``.
    """

    def __init__(self, p: int):
        self.p = p

    def multiply(self, a: int, b: int) -> int:
        """Return ``a * b`` reduced modulo ``p``."""
        product = a * b
        return product % self.p

    def add(self, a: int, b: int) -> int:
        """Return ``a + b`` reduced modulo ``p``."""
        total = a + b
        return total % self.p

    def sub(self, a: int, b: int) -> int:
        """Return ``a - b`` reduced modulo ``p``."""
        difference = a - b
        return difference % self.p

    def div(self, a: int, b: int) -> int:
        """Return ``a / b`` modulo ``p``, i.e. ``a`` times the inverse of ``b``.

        Raises:
            ZeroDivisionError: if ``b`` is congruent to 0 modulo ``p``.
        """
        divisor = b % self.p
        if not divisor:
            raise ZeroDivisionError
        # Euclidean algorithm: the first Bezout coefficient of (divisor, p)
        # is the multiplicative inverse of divisor modulo p.
        bezout = egcd(divisor, self.p)
        inverse = bezout[0] % self.p
        return self.multiply(a, inverse)

Mathematically, this means that we have a field, \(\mathbb{F}_p\). Real numbers form a field too, so you probably have a good intuition for how fields work.

Here are arithmetic tables for \(\mathbb{F}_{7}\), skipping subtraction as I don’t have space:

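Multiplication (row * column):

 * | 1 2 3 4 5 6
---+------------
 1 | 1 2 3 4 5 6
 2 | 2 4 6 1 3 5
 3 | 3 6 2 5 1 4
 4 | 4 1 5 2 6 3
 5 | 5 3 1 6 4 2
 6 | 6 5 4 3 2 1

Division (column / row):

 / | 1 2 3 4 5 6
---+------------
 1 | 1 2 3 4 5 6
 2 | 4 1 5 2 6 3
 3 | 5 3 1 6 4 2
 4 | 2 4 6 1 3 5
 5 | 3 6 2 5 1 4
 6 | 6 5 4 3 2 1

Addition (row + column):

 + | 1 2 3 4 5 6
---+------------
 1 | 2 3 4 5 6 0
 2 | 3 4 5 6 0 1
 3 | 4 5 6 0 1 2
 4 | 5 6 0 1 2 3
 5 | 6 0 1 2 3 4
 6 | 0 1 2 3 4 5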

Plot twist: we are now going to deal with an elliptic curve defined over \(\mathbb{F}_p\). Turns out the equations for addition with respect to a curve, \eqref{eqn:point_r}, are still well-defined: they only use simple arithmetic operations (*, /, +, -) anyway. Unfortunately, we no longer have the geometric intuition, as our curve plotted will now look like just a bunch of points.

Part 3: Toy Elliptic Curve over \(\mathbb{F}_{37}\)

I wrote some auxiliary code that implements fields and elliptic curves, attached in the appendix. For now, let’s play with it to get used to elliptic curves over prime fields:

def generate_span(start_point):
    """Return ``[P, 2P, 3P, ...]`` for ``P = start_point``, ending at zero.

    ``start_point`` is added to the running total until the total reports
    itself as the zero point (``.zero`` is true); that zero point is the last
    element of the returned list.
    """
    current = start_point
    multiples = [current]
    while not current.zero:
        current = current + start_point
        multiples.append(current)
    return multiples

# Curve y^2 = x^3 + 11x + 16; the coefficients are plain ints here.
ec = EC(11, 16)
# Arithmetic modulo the prime 37.
field_37 = Field(37)
# With y=None the point constructor asks the curve for a matching y (a square
# root of x^3 + 11x + 16 in the field) instead of taking it as given.
point1 = ec.point(field_37.n(0), None)
# point1, 2*point1, 3*point1, ... up to and including the zero point.
span1 = generate_span(point1)

> span1
[[0_₃₇, 4]_₁₁,₁₆,
 [36_₃₇, 2]_₁₁,₁₆,
 [5_₃₇, 23]_₁₁,₁₆,
 [5_₃₇, 14]_₁₁,₁₆,
 [36_₃₇, 35]_₁₁,₁₆,
 [0_₃₇, 33]_₁₁,₁₆,
 [None, None]_₁₁,₁₆] # zero

I started with the point \((0, 4)\), modulo 37, and kept adding it to itself until I reached \(O\). I have to reach \(O\), because the number of points on the curve is finite, so the multiples have to start repeating:

\[nP = mP\]

for some \(n \neq m\), which means \((n - m)P = O\). Are there any more points on the curve? Let’s find out:

def find_outside_span(point):
    """Return a curve point that is not a multiple of ``point``, if any.

    Every x in the field of ``point`` is tried in increasing order; for each x
    that has a point on the curve, that point is checked against the span
    generated by ``point``.

    The curve is taken from ``point.ec`` rather than from a module-level
    variable, so the function works for a point on any curve.

    Returns:
        The first curve point found outside the span, or ``None`` when every
        point that was tried lies inside it.
    """
    curve = point.ec
    field = point.x.field
    span = generate_span(point)
    for i in range(field.p):
        try:
            # y=None asks the curve to find a y for this x.
            candidate = curve.point(field.n(i), None)
        except ValueError:
            # No square root exists for this x, so no curve point has it.
            continue

        if candidate not in span:
            return candidate
    return None

# A curve point that is not a multiple of point1 (None if there is no such point).
point2 = find_outside_span(point1)
>point2
[1_₃₇, 18]_₁₁,₁₆

Yes, there’s at least one more point: \((1, 18)\). Any more?

> find_outside_span(point2)
None

Looks like \((1, 18)\) generates all points on our curve. How many points do we have? 49:

> len(generate_span(point2)), len(span1)
49, 7

In mathematics, we say that the order of point2 is 49, whilst the order of point1 is 7, where the order is the number of times we add the point to itself until we reach zero. point2 in this toy example plays the same role that the number \(1\) plays in the integers: by adding \(1\) repeatedly we can generate all positive integers. The same holds for point2: by adding it repeatedly we generate all points on the curve. OK, hopefully you now have a better feel for this object. Let’s proceed.

Part 4: Ethereum Address

I went on vanity-eth.tk and generated a public address and a private key:

0xeace72e038dd840a468986654448a7d6d3e1a750 # public
0xf6096531d756e52953455522cacc2a1effe504577233df66d3a3110231ff918c # private

We have already covered all the material needed to replicate such computation. Both Ethereum and Bitcoin use an elliptic curve \(y^2 = x^3 + 7\) over \(\mathbb{F}_p\), where \(p = 2^{256} - 2^{32} - 2^{9} - 2^{8} - 2^{7} - 2^{6} - 2^{4} - 1\), a very large prime.

It turns out that whilst calculating \(n * P = Q\), i.e. summing \(P\) with itself \(n\) times, is computationally easy, finding \(n\) given \(P\) and \(Q\) is computationally hard. So the private key is \(n\), whilst the public key is \(n * P\), for some specified \(P\). Below is the code that defines the curve and the specified point, taken from the secp256k1 specs:

# Prime modulus of the field the curve is defined over.
p = 2**256 - 2**32 - 2**9 - 2**8 - 2**7 - 2**6 - 2**4 - 1
# hex codes for compactness, those are just large integers
x = 0x79BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798
y = 0x483ADA7726A3C4655DA4FBFC0E1108A8FD17B448A68554199C47D08FFB10D4B8
# Curve y^2 = x^3 + 7 (a=0, b=7). y is passed explicitly, so the point
# constructor does not have to search for a square root.
point = EC(0, 7).point(Field(p).n(x), y)

Finally, we can generate the public point using multiplication:

# private key * point: an int on the left dispatches to EC.Point.__rmul__.
pp = 0xf6096531d756e52953455522cacc2a1effe504577233df66d3a3110231ff918c * point

That’s almost our public key, but not quite. We now encode (x, y) coordinates of our public point in hex, to get something like

def public(self) -> str:
    """Return the point's coordinates as one 128-character hex string.

    x and y are each written as 64 zero-padded lowercase hex digits and
    concatenated, x first. Coordinates held as ``Field.Int`` are unwrapped to
    their integer value first.

    Raises:
        ValueError: if either coordinate is ``None`` (the zero point).
    """
    if self.x is None or self.y is None:
        # This line was indented with a tab under a space-indented ``if``,
        # which Python 3 rejects with TabError; it now uses spaces.
        raise ValueError("Zero can't be public")

    x = self.x.a if isinstance(self.x, Field.Int) else self.x
    y = self.y.a if isinstance(self.y, Field.Int) else self.y
    return f"{x:064x}{y:064x}"

> pp.public()
c662ecf3e1c45eefbacd6244863dff8a23e414fc7f52567dcaf1796c898730ab8392
c6811de16d4c235a7522b83cba9ed1824af40ac72932863689a342215033

Finally, we compute the Keccak-256 hash of the encoded public point, and take the last 40 hex characters (20 bytes):

> pp.ethereum()
0xeace72e038dd840a468986654448a7d6d3e1a750

Voilà!

Credits

Thanks to the bellbind GitHub account for the toy implementations and to Timur Badretdinov for his blog post.

Appendix

ALL THE CODE

from __future__ import annotations
from typing import Optional, Union, Tuple
import math
import codecs


def to_subscript(n):
    """Return ``str(n)`` with each decimal digit replaced by its subscript form.

    Characters other than ``0``-``9`` are left unchanged.
    """
    # U+2080..U+2089 are SUBSCRIPT ZERO..SUBSCRIPT NINE, in digit order.
    digit_map = {ord("0") + d: chr(0x2080 + d) for d in range(10)}
    text = str(n)
    return text.translate(digit_map)


def egcd(a, b) -> Tuple[int, int, int]:
    """Extended Euclidean algorithm.

    Returns ``(s, t, gcd)`` such that ``a*s + b*t == gcd``, where ``gcd`` is
    the non-negative greatest common divisor of ``a`` and ``b``.

    >>> s, t, gcd = egcd(240, 46)
    >>> gcd
    2
    >>> 240 % gcd == 0 and 46 % gcd == 0
    True
    >>> 240 * s + 46 * t == gcd
    True
    """
    # Invariant: (current a) == a0*s0 + b0*t0 and (current b) == a0*s1 + b0*t1,
    # where a0, b0 are the original arguments.
    s0, s1, t0, t1 = 1, 0, 0, 1
    # Loop on ``b != 0`` rather than ``b > 0`` so that a negative ``b`` is
    # reduced as well. divmod's floor semantics keep |remainder| < |b|, so the
    # loop still terminates.
    while b != 0:
        q, r = divmod(a, b)
        a, b = b, r
        s0, s1, t0, t1 = s1, s0 - q * s1, t1, t0 - q * t1
    if a < 0:
        # Flip all three signs together: the identity a0*s + b0*t == gcd is
        # preserved and the reported gcd becomes non-negative.
        s0, t0, a = -s0, -t0, -a
    return s0, t0, a


class Field:
    """Integers modulo ``p`` with field-style arithmetic.

    The methods on ``Field`` itself work on plain ints and return plain ints
    reduced modulo ``p``. ``Field.Int`` wraps a residue together with its
    field so that the usual operators (``+ - * / **``, unary minus, ``==``)
    can be used directly. ``p`` is expected to be prime; this is not checked.
    """

    class Int:
        """One element of a ``Field``: the residue ``a`` and its ``field``."""

        def to_int(self, other: Union[Field.Int, int]) -> int:
            """Return the integer value of ``other`` for use in arithmetic.

            Plain ints are returned unchanged. Objects exposing ``a`` and
            ``field`` attributes must belong to a field with the same modulus
            (checked with an ``assert``). Anything else raises
            ``NotImplementedError``.
            """
            if isinstance(other, int):
                return other
            elif hasattr(other, 'a') and hasattr(other, 'field'):
                assert self.field.p == other.field.p
                return other.a
            else:
                raise NotImplementedError

        def __init__(self, field: Field, a: int):
            self.field, self.a = field, a

        def __str__(self) -> str:
            return str(self.a)

        def __repr__(self) -> str:
            # e.g. "4_₃₇": the residue followed by the modulus in subscript.
            return f"{self.a}_{to_subscript(self.field.p)}"

        def __mul__(self, other: Union[Field.Int, int]) -> Field.Int:
            return self.field.n(self.field.multiply(self.a, self.to_int(other)))

        def __rmul__(self, other: int) -> Field.Int:
            # Multiplication is commutative, so int * Int reuses Int * int.
            return self.__mul__(other)

        def __add__(self, other: Union[Field.Int, int]) -> Field.Int:
            return self.field.n(self.field.add(self.a, self.to_int(other)))

        def __radd__(self, other: int) -> Field.Int:
            return self.__add__(other)

        def __sub__(self, other: Union[Field.Int, int]) -> Field.Int:
            return self.field.n(self.field.sub(self.a, self.to_int(other)))

        def __rsub__(self, other: int) -> Field.Int:
            # other - self: the operands are swapped relative to __sub__.
            return self.field.n(self.field.sub(self.to_int(other), self.a))

        def __neg__(self) -> Field.Int:
            return self.field.n(self.field.sub(0, self.a))

        def __truediv__(self, other: Union[Field.Int, int]) -> Field.Int:
            return self.field.n(self.field.div(self.a, self.to_int(other)))

        def __rtruediv__(self, other: int) -> Field.Int:
            # other / self: the operands are swapped relative to __truediv__.
            return self.field.n(self.field.div(self.to_int(other), self.a))

        def __pow__(self, other: int) -> Field.Int:
            return self.field.n(self.field.pow(self.a, other))

        def __eq__(self, other: object):
            """Compare modulo ``p`` with another element or a plain int.

            Elements of fields with different moduli, and objects of any
            other type, compare unequal.
            """
            if isinstance(other, Field.Int):
                # Check the moduli here instead of going through to_int():
                # its same-field assertion would turn ``==`` between elements
                # of different fields into an AssertionError.
                return (self.field.p == other.field.p
                        and self.field.eq(self.a, other.a))
            if isinstance(other, int):
                return self.field.eq(self.a, other)
            return False

    def __init__(self, p: int):
        self.p = p

    def n(self, a: int) -> Field.Int:
        """Wrap the integer ``a`` (reduced modulo ``p``) as a field element."""
        return Field.Int(self, a % self.p)

    def multiply(self, a: int, b: int) -> int:
        """Return ``a * b`` modulo ``p``."""
        return a * b % self.p

    def add(self, a: int, b: int) -> int:
        """Return ``a + b`` modulo ``p``."""
        return (a + b) % self.p

    def sub(self, a: int, b: int) -> int:
        """Return ``a - b`` modulo ``p``."""
        return (a - b) % self.p

    def div(self, a: int, b: int) -> int:
        """Return ``a / b`` modulo ``p``.

        Raises:
            ZeroDivisionError: if ``b`` is congruent to 0 modulo ``p``.
        """
        b = b % self.p
        if b == 0:
            raise ZeroDivisionError
        # The first Bezout coefficient of (b, p) is the inverse of b modulo p.
        inv_b = egcd(b, self.p)[0] % self.p
        return self.multiply(a, inv_b)

    def pow(self, a: int, n: int) -> int:
        """Return ``a ** n`` modulo ``p``.

        A negative ``n`` is handled by inverting ``a`` first, so
        ``ZeroDivisionError`` is raised when ``a`` is congruent to 0.
        """
        if n < 0:
            a, n = self.div(1, a), -n
        # Inside a method body ``pow`` is the built-in, not this method; its
        # three-argument form performs modular square-and-multiply.
        return pow(a, n, self.p)

    def eq(self, a: int, b: int) -> bool:
        """Return whether ``a`` and ``b`` are congruent modulo ``p``."""
        return (a - b) % self.p == 0

    def sqrt(self, a: int) -> int:
        """Return the smallest ``i`` in ``[0, p)`` with ``i*i == a`` modulo ``p``.

        This is an exhaustive search over all residues, so it is only usable
        for small ``p``.

        Raises:
            ValueError: if no residue squares to ``a``.
        """
        a = a % self.p
        for i in range(self.p):
            if (i * i) % self.p == a:
                return i
        raise ValueError("not found")


class EC:
    """Elliptic curve ``y^2 = x^3 + a*x + b``.

    The coefficients and point coordinates may be ordinary numbers or
    ``Field.Int`` elements; the point arithmetic only relies on the operators
    ``+ - * / **`` and ``==`` of whatever type is used.
    """

    class Point:
        """A point on an ``EC`` curve, or the zero point (both coordinates None).

        NOTE(review): when ``x`` is a ``Field.Int`` but ``y`` is a plain int,
        negating and comparing ``y`` uses ordinary integer arithmetic rather
        than arithmetic modulo ``p`` - confirm callers pass reduced, matching
        representations when they construct points from raw integers.
        """

        def __init__(self, ec: EC,
                     x: Optional[Union[Field.Int, float]] = None,
                     y: Optional[Union[Field.Int, float]] = None):
            self.zero = x is None and y is None
            if x is not None and y is None:
                # Only x given: let the curve compute a matching y.
                y = ec.at(x)
            self.ec, self.x, self.y = ec, x, y

        def __eq__(self, other) -> bool:
            """Points are equal when both coordinates are equal."""
            if not isinstance(other, EC.Point):
                # Let Python fall back to its default handling (unequal)
                # instead of failing on a missing ``.x`` attribute.
                return NotImplemented
            return self.x == other.x and self.y == other.y

        def __neg__(self) -> EC.Point:
            """Return the reflection across the x-axis; zero stays zero."""
            if self.x is not None and self.y is not None:
                return self.ec.point(self.x, -self.y)

            return self.ec.point(x=None, y=None)

        def __add__(self, other: EC.Point) -> EC.Point:
            """Add two points using the chord-and-tangent rule."""
            # The zero point is the identity of the addition.
            if self.x is None or self.y is None:
                return other
            if other.x is None or other.y is None:
                return self
            # P + (-P) is zero; this also covers doubling a point with y == 0.
            if self == -other:
                return self.ec.point(None, None)
            if self == other:
                # Doubling: slope of the tangent at the point.
                slope = (3 * (self.x)**2 + self.ec.a) / (2 * self.y)
            else:
                # Distinct points: slope of the chord through them.
                slope = (other.y - self.y) / (other.x - self.x)

            x = slope**2 - self.x - other.x
            y = slope * (self.x - x) - self.y

            return self.ec.point(x, y)

        def __mul__(self, n: int) -> EC.Point:
            """Return ``n`` times this point using double-and-add.

            ``n == 0`` gives the zero point. A negative ``n`` multiplies the
            negated point by ``-n`` instead of silently returning zero.
            """
            if n < 0:
                return (-self).__mul__(-n)
            r, m2 = self.ec.point(None, None), self
            # O(log2(n)) add
            while n > 0:
                if n & 1 == 1:
                    r += m2
                n >>= 1
                if n:
                    # Double only while bits remain; the doubling after the
                    # last bit would never be used.
                    m2 = m2 + m2
            return r

        def __rmul__(self, n: int) -> EC.Point:
            return self.__mul__(n)

        def __str__(self):
            return f"[{self.x}, {self.y}]"

        def __repr__(self):
            # Coordinates followed by the curve coefficients in subscript.
            t = to_subscript
            return f"[{repr(self.x)}, {str(self.y)}]_{t(self.ec.a)},{t(self.ec.b)}"

        def public(self) -> str:
            """Return x and y as 64 zero-padded hex digits each, x first.

            Raises:
                ValueError: if the point is the zero point.
            """
            if self.x is None or self.y is None:
                raise ValueError("Zero can't be public")

            x = self.x.a if isinstance(self.x, Field.Int) else self.x
            y = self.y.a if isinstance(self.y, Field.Int) else self.y
            return f"{x:064x}{y:064x}"

        def ethereum(self) -> str:
            """Return ``0x`` plus the last 40 hex digits of Keccak-256(public())."""
            # pip install pycryptodomex
            from Cryptodome.Hash import keccak
            public_key_bytes = codecs.decode(self.public(), 'hex')
            keccak_hash = keccak.new(digest_bits=256)
            keccak_hash.update(public_key_bytes)
            keccak_digest = keccak_hash.hexdigest()
            # Take the last 20 bytes
            wallet_len = 40
            wallet = '0x' + keccak_digest[-wallet_len:]
            return wallet

    def __init__(self, a: Union[Field.Int, float], b: Union[Field.Int, float]):
        self.a, self.b = a, b  # y^2 = (x^3 + ax + b) % p

    def point(self, x: Optional[Union[Field.Int, float]], y: Optional[Union[Field.Int, float]]):
        """Create a point on this curve; ``point(None, None)`` is the zero point."""
        return EC.Point(self, x, y)

    def y_squared(self, x: Union[Field.Int, float]) -> Union[Field.Int, float]:
        """Return the right-hand side ``x^3 + a*x + b`` of the curve equation."""
        return x**3 + self.a * x + self.b

    def at(self, x: Union[Field.Int, float]) -> Union[Field.Int, float]:
        """Return a y-coordinate for the given x.

        For a plain number this is ``math.sqrt`` of the right-hand side. For a
        ``Field.Int`` the right-hand side is evaluated on the underlying
        integer and passed to the field's ``sqrt``; the result is a plain int.

        Raises:
            NotImplementedError: for any other type of ``x``.
        """
        if isinstance(x, (int, float)):
            return math.sqrt(self.y_squared(x))
        if isinstance(x, Field.Int):
            return x.field.sqrt(self.y_squared(x.a))
        raise NotImplementedError