Diego Rodriguez Mancini

Software Engineer

Applying Dependency Injection in a Flask Application
Published on Apr 09, 2022 by Diego Rodriguez Mancini.

Introduction

This is a tutorial / example of applying dependency injection in a Flask application, using the Dependency Injector library.

Flask is a very minimal framework. It lets you create HTTP services quickly without forcing a particular structure on you. This can be good or bad, depending on what you are building and what you want to achieve. With so much freedom, you might end up with an application that has no structure at all. On the other hand, you can try to use all the tools the framework provides, which is hard to do when you are not familiar with Flask.

It is common to start with a small app that has a few endpoints and not many features. As your application evolves, you add more and more functionality, to the point where you have hundreds of lines of code in a single route function.

When you start breaking your code into separate modules, a big problem can remain: your code is split into different files or classes, but it is still tightly coupled.

Database connections are still hard-coded into your now “modularized” code. Every new class you create gets instantiated somewhere else in your code, which ties that code to a concrete dependency and makes it hard to change.

Code with tight dependencies is also harder to test. You have to start mocking things that are part of the code you want to test, and you might end up in a “mocking hell” situation.

This post presents a way of structuring a Flask application that is not Flask-specific, following architectural patterns such as Clean Architecture. It also guides you through the process of applying Dependency Injection to decouple your dependencies, which gives you more scalability, better tests and an overall easier development experience.

What we are going to build

A random meme generator application that fetches images from a known repository and inserts user-submitted text on top of them.

Some key notes:

  • The app should fetch images from a local filesystem repository at first, but other sources must be supported, such as Amazon S3
  • Text is provided by the user using querystring parameters
  • Top and bottom text should be added to the base image, just as in any good meme (in this tutorial I’m not going into the details of how the text is drawn on the images)
  • Meme is created from a random image fetched from whatever source we are using

The final code repository is available on GitLab at https://gitlab.com/dirodriguezm/dirodriguezm.gitlab.io, and each incremental step has its own tag, so you can follow the progress and see what was added in each increment.

The application in motion looks like this:

sample_request.png

Base Application

Find this part of the tutorial in the Base-Application tag.

The first step is to build the application. We are going to follow a directory structure that resembles the Clean Architecture pattern.

|- assets
|- src
   |- application
   |- domain
      |- entities
   |- infrastructure
      |- gateways

  • assets: directory where we store the local images (and the font used to draw the text)
  • src: all the application code
  • src/application: Flask-specific code: main app, views, etc.
  • src/domain: core use cases and domain entities
  • src/infrastructure: application gateways, entity repositories, etc.

I will assume that you know the basics of Clean Architecture, or at least the MVC pattern, so I’ll skip defining some of the concepts.

The code works like this. The application module contains the Flask code (the Flask app), and its views (routes) act as controllers for the rest of the app. Routes parse parameters and then ask the core services, which are part of the domain module, for the result. In this case we have only one core service, or use case: the MemeGenerator class, which adds text to images. The core service asks the repositories for data (images), and the repository is responsible for using gateways such as LocalImageGateway or S3ImageGateway to fetch the images.

base_class_diagram.png
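
To make that call chain concrete, here is a minimal sketch with hypothetical stand-in classes (FakeGateway, Repository, Generator and generate_meme_view are not part of the article's code). They return strings instead of PIL images so the flow runs without Flask or Pillow; the point is that each object receives its collaborator through the constructor instead of creating it.

```python
class FakeGateway:
    """Stand-in for LocalImageGateway: returns a file name instead of a PIL image."""

    def get_image(self, index: int) -> str:
        return f"image{index}.png"


class Repository:
    """Stand-in for ImageRepository: asks its gateway for an image."""

    def __init__(self, gateway):
        self.gateway = gateway  # injected, not created here

    def get_image(self) -> str:
        return self.gateway.get_image(1)  # fixed index instead of a random one


class Generator:
    """Stand-in for MemeGenerator: the use case only talks to the repository."""

    def __init__(self, image_repository):
        self.image_repository = image_repository

    def create_meme(self, top_text: str, bottom_text: str) -> str:
        image = self.image_repository.get_image()
        return f"{top_text} | {image} | {bottom_text}"


def generate_meme_view(args: dict) -> str:
    """Stand-in for the Flask route: parse parameters, then call the use case."""
    generator = Generator(image_repository=Repository(gateway=FakeGateway()))
    return generator.create_meme(args.get("top_text", ""), args.get("bottom_text", ""))


print(generate_meme_view({"top_text": "hello", "bottom_text": "world"}))
# prints: hello | image1.png | world
```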

Application code

This is where the usual Flask code goes. We create an app object and define a couple of views, one for the root endpoint and one for the memes.

from flask import Flask, request, send_file
from domain.meme_generator import MemeGenerator
from infrastructure.image_repository import ImageRepository
from io import BytesIO

app = Flask(__name__)


@app.route("/")
def hello_world():
    return "<p>Hello, World!</p>"


@app.route("/generate_meme")
def generate_meme():
    repository = ImageRepository(
        "/home/diego/Projects/meme-generator/assets/local_images"
    )
    generator = MemeGenerator(image_repository=repository)
    args = request.args
    meme = generator.create_meme(
        args.get("top_text", ""), args.get("bottom_text", "")
    )
    bio = BytesIO()
    meme.image.save(bio, "PNG", quality=100)
    bio.seek(0)
    return send_file(bio, mimetype="image/png")

Notice how we have to instantiate all our classes here. The MemeGenerator uses an ImageRepository (at least we are not instantiating the repository inside the MemeGenerator), so we have to pass it in as an argument.

Meme Generator

The MemeGenerator class is the core use case handler. It fetches an image using the repository, creates a Meme instance, calls the Meme's generate_meme method to add the text to the image, and then returns the Meme instance.

from infrastructure.image_repository import (
    ImageRepositoryInterface,
)
from domain.entities.meme import Meme


class MemeGeneratorInterface:
    def create_meme(self, top_text: str, bottom_text: str) -> Meme:
        raise NotImplementedError()


class MemeGenerator(MemeGeneratorInterface):
    def __init__(self, image_repository: ImageRepositoryInterface):
        self.image_repository = image_repository

    def create_meme(self, top_text: str, bottom_text: str) -> Meme:
        img = self.image_repository.get_image()
        meme = Meme(img, top_text, bottom_text)
        meme.generate_meme()
        return meme

Meme

The Meme entity represents an image with text. It holds the core logic of the use case, which is adding top and bottom text to an image.

from PIL.Image import Image
from PIL import ImageDraw
from PIL import ImageFont


class Meme:
    def __init__(self, image: Image, top_text: str, bottom_text: str):
        self.image = image
        # Missing query parameters arrive as None; treat them as empty text
        self.top_text = top_text or ""
        self.bottom_text = bottom_text or ""

    def split_text(self, text: str) -> str:
        # Wrap every `max_words` words; wider images fit more words per line
        words = text.split()
        max_words = max(1, self.image.width // 150)
        lines = [
            " ".join(words[i : i + max_words])
            for i in range(0, len(words), max_words)
        ]
        return "\n".join(lines)

    def draw_outlined_text(self, imdraw, xy, text: str, font):
        x, y = xy
        # Black outline: the same text shifted one pixel in each diagonal direction
        for dx, dy in ((-1, -1), (1, -1), (-1, 1), (1, 1)):
            imdraw.text((x + dx, y + dy), text, font=font, fill=(0, 0, 0))
        # White text drawn on top of the outline
        imdraw.text((x, y), text, font=font, fill=(255, 255, 255))

    def generate_meme(self):
        imdraw = ImageDraw.Draw(self.image)
        font_size = int(self.image.height / 12)
        font = ImageFont.truetype("assets/fonts/impact.ttf", font_size)
        top = self.split_text(self.top_text)
        bottom = self.split_text(self.bottom_text)
        self.draw_outlined_text(imdraw, (10, 10), top, font)
        self.draw_outlined_text(imdraw, (10, self.image.height - 30), bottom, font)

Image Repository

We are going to use local filesystem images stored in the assets directory, but the class itself receives the path as an argument.

from PIL import Image
import random


class ImageRepositoryInterface:
    def get_image(self):
        raise NotImplementedError()


class ImageRepository(ImageRepositoryInterface):
    def __init__(self, path: str):
        self.path = path

    def get_image(self) -> Image.Image:
        r = random.randint(1, 10)
        try:
            return Image.open(f"{self.path}/image{r}.png")
        except FileNotFoundError:
            return Image.open(f"{self.path}/image{r}.jpg")
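
The repository above assumes exactly ten images named image1 to image10, each stored as .png or .jpg. If that assumption does not hold, one alternative is to choose among whatever image files are in the directory. The sketch below is a hypothetical helper (DirectoryImagePicker is not in the article) that returns a path for Image.open; it also takes the random generator as a constructor argument, which is the same dependency injection idea applied to randomness, so tests can seed it.

```python
import random
from pathlib import Path
from typing import Optional


class DirectoryImagePicker:
    """Hypothetical helper: chooses a random image file from a directory.

    It returns the file path; a repository would then pass it to PIL's Image.open.
    """

    IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg"}

    def __init__(self, path: str, rng: Optional[random.Random] = None):
        self.path = Path(path)
        # Injecting the random generator makes the choice reproducible in tests
        self.rng = rng if rng is not None else random.Random()

    def pick(self) -> Path:
        # Only consider files with an image extension, in a stable order
        candidates = sorted(
            p
            for p in self.path.iterdir()
            if p.is_file() and p.suffix.lower() in self.IMAGE_SUFFIXES
        )
        if not candidates:
            raise FileNotFoundError(f"No images found in {self.path}")
        return self.rng.choice(candidates)
```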

Base Application Conclusion

This is the basic functionality of our app. We don’t have any tests yet, we only have one image repository (for local files), and we have to instantiate each dependency manually in the view. Even though MemeGenerator already receives its repository as an argument, that wiring is done by hand, and if we wanted to test this code right now, we would be stuck with the repository hard-coded in the view.

Next, we are going to write an S3 gateway so that, instead of the local filesystem, we fetch data from S3 (not really; we will only pretend to).

Multiple Data Sources

We will now introduce the Gateway class, which connects an external data source to our code. Repositories now use a gateway to fetch data.

gateway_class_diagram.png

Repository

The repository now uses a gateway to fetch an image. We apply dependency injection by passing the gateway instance as an argument.

import random
from PIL import Image
from infrastructure.gateways.gateway import ImageGateway


class ImageRepositoryInterface:
    def get_image(self):
        raise NotImplementedError()


class ImageRepository(ImageRepositoryInterface):
    def __init__(self, gateway: ImageGateway):
        self.gateway = gateway

    def get_image(self) -> Image.Image:
        r = random.randint(1, 10)
        return self.gateway.get_image(r)

Gateway

The generic class looks like this:

class ImageGateway:
    def get_image(self, index: int):
        raise NotImplementedError()
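
This base class only fails when the missing method is actually called. If you would rather catch a gateway that forgets to implement get_image as soon as it is instantiated, Python's abc module can declare the interface; this is a swapped-in alternative, not what the article uses. InMemoryImageGateway and IncompleteGateway are hypothetical examples.

```python
from abc import ABC, abstractmethod


class ImageGateway(ABC):
    @abstractmethod
    def get_image(self, index: int):
        """Return the image stored under `index`."""


class InMemoryImageGateway(ImageGateway):
    """Hypothetical gateway that serves images from a dict (handy in tests)."""

    def __init__(self, images: dict):
        self.images = images

    def get_image(self, index: int):
        return self.images[index]


class IncompleteGateway(ImageGateway):
    """Forgets to implement get_image, so Python refuses to instantiate it."""
```

With the abstract base, `IncompleteGateway()` raises TypeError immediately instead of failing later inside a request.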

Local

Same logic as the previous repository:

from infrastructure.gateways.gateway import ImageGateway
from PIL import Image


class LocalImageGateway(ImageGateway):
    def __init__(self, path: str):
        self.path = path

    def get_image(self, index: int) -> Image.Image:
        try:
            return Image.open(f"{self.path}/image{index}.png")
        except FileNotFoundError:
            return Image.open(f"{self.path}/image{index}.jpg")

S3

Still not implemented. Only created for demonstration purposes.

from infrastructure.gateways.gateway import ImageGateway


class S3ImageGateway(ImageGateway):
    def __init__(self, path: str):
        self.path = path

Application code

Same as before, but now we instantiate the repository with a gateway instance.

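from io import BytesIO

from flask import Flask, request, send_file
from domain.meme_generator import MemeGenerator
from infrastructure.image_repository import ImageRepository
from infrastructure.gateways.local_image_gateway import LocalImageGateway

app = Flask(__name__)


@app.route("/generate_meme")
def generate_meme():
    # Every dependency is still built by hand, inside the view
    gateway = LocalImageGateway(
        path="assets/local_images"
    )
    repository = ImageRepository(gateway=gateway)
    generator = MemeGenerator(image_repository=repository)
    args = request.args
    meme = generator.create_meme(
        args.get("top_text", ""), args.get("bottom_text", "")
    )
    bio = BytesIO()
    meme.image.save(bio, "PNG", quality=100)
    bio.seek(0)
    return send_file(bio, mimetype="image/png")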

Gateways Conclusion

We have added gateways to fetch images from multiple sources. The repository now depends on these gateways, and we have to instantiate them in the view as well. You can see how instantiating many dependencies adds boilerplate to the view, so we could move all of them to a single place (called a container) and import the dependencies from there.

We could do just that and call it a day: our application is small, and if we wanted to change dependencies for testing, we could edit the container or create a different container for tests. But there is a better solution, one that lets the application grow in a more manageable way, makes testing easier by swapping dependencies without changing containers, and supports composite containers (and much more complex dependency management).
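
For comparison, the hand-made container described above could look roughly like this. It is a hypothetical sketch: ManualContainer is not part of the article, and LocalImageGateway and ImageRepository here are simplified stand-ins that return paths instead of images.

```python
class LocalImageGateway:
    """Stand-in for the article's LocalImageGateway (returns a path, not an image)."""

    def __init__(self, path: str):
        self.path = path

    def get_image(self, index: int) -> str:
        return f"{self.path}/image{index}.png"


class ImageRepository:
    """Stand-in for the article's ImageRepository (fixed index, no randomness)."""

    def __init__(self, gateway):
        self.gateway = gateway

    def get_image(self) -> str:
        return self.gateway.get_image(1)


class ManualContainer:
    """Hypothetical hand-written container: the one place that wires objects together."""

    def __init__(self, images_path: str, image_gateway=None):
        # Tests can pass their own gateway; otherwise the real one is built
        if image_gateway is None:
            image_gateway = LocalImageGateway(images_path)
        self.image_gateway = image_gateway
        self.image_repository = ImageRepository(gateway=self.image_gateway)


container = ManualContainer("assets/local_images")
print(container.image_repository.get_image())  # prints: assets/local_images/image1.png
```

Swapping the gateway for a test means building a whole new container, and every object is created as soon as the container is. Those are the two points the Dependency Injector approach below improves on, with providers that create objects only when they are requested and that can be overridden in place.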

Dependency Injector

From the documentation: Dependency Injector is a dependency injection framework for Python. It helps implementing the dependency injection principle.

That is exactly what we need right now. With our code already structured in a way that applies Dependency Injection, now we have to integrate it with Dependency Injector.

Containers

from dependency_injector import containers, providers
from infrastructure.gateways.local_image_gateway import LocalImageGateway
from infrastructure.image_repository import ImageRepository
from domain.meme_generator import MemeGenerator


class Container(containers.DeclarativeContainer):
    wiring_config = containers.WiringConfiguration(
        packages=["application.views"]
    )
    config = providers.Configuration(ini_files=["config.ini"])
    image_gateway = providers.Singleton(
        LocalImageGateway, path=config.image_repository.path
    )
    image_repository = providers.Singleton(
        ImageRepository, gateway=image_gateway
    )
    meme_generator = providers.Singleton(
        MemeGenerator, image_repository=image_repository
    )

Essentially, what we are doing here is defining how each dependency is going to be initialized. With Dependency Injector, we don’t instantiate these classes ourselves. They are instantiated when the code that uses them is called, and the instantiation follows the rules declared in this container.

There are a few key elements in the above code: wiring, config and the actual dependencies.

First, the container is wired to a specific module. This means that said module can inject dependencies declared in this container. Injecting means that the module can use these classes without instantiating them explicitly.

Second, we can get parameters from a config file. We tell the container to look for its configuration in an INI file, and then we can use any variable declared in that file to instantiate our dependencies.
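
The article doesn't show config.ini. Since Dependency Injector's Configuration provider uses INI sections as top-level keys and options as second-level keys, config.image_repository.path would come from an option called path in an [image_repository] section. The sketch below assumes such a file, with the value borrowed from the path used in the earlier view, and reads it with the standard library's configparser only to illustrate the mapping; the container parses the file on its own.

```python
import configparser

# Hypothetical config.ini contents; the article does not show this file
CONFIG_INI = """
[image_repository]
path = assets/local_images
"""

parser = configparser.ConfigParser()
parser.read_string(CONFIG_INI)

# Section [image_repository], option "path": the value the container reads
# as config.image_repository.path
images_path = parser["image_repository"]["path"]
print(images_path)  # prints: assets/local_images
```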

Lastly, we tell the container how each class should be provided. Here we are using Singletons, which means that every time one of these dependencies is injected, we get the same instance of the class. This is useful for database connections, for example; in this case we don’t mind whether the dependencies are Singletons or get a new instance on each reference.
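
Dependency Injector offers both lifetimes: providers.Singleton and providers.Factory (a new instance on every call). The classes below are simplified plain-Python stand-ins, not the library's implementation, that show the difference as well as the lazy creation mentioned earlier: nothing is constructed until a provider is first called. Connection and its URL are hypothetical.

```python
class Factory:
    """Simplified stand-in for a factory provider: a new object on every call."""

    def __init__(self, cls, *args, **kwargs):
        self.cls = cls
        self.args = args
        self.kwargs = kwargs

    def __call__(self):
        return self.cls(*self.args, **self.kwargs)


class Singleton(Factory):
    """Simplified stand-in for a singleton provider: built on first call, then reused."""

    def __init__(self, cls, *args, **kwargs):
        super().__init__(cls, *args, **kwargs)
        self._instance = None

    def __call__(self):
        if self._instance is None:
            self._instance = super().__call__()
        return self._instance


created = []


class Connection:
    """Toy dependency that records every time it is constructed."""

    def __init__(self, url: str):
        self.url = url
        created.append(self)


singleton = Singleton(Connection, "db://example")
factory = Factory(Connection, "db://example")
```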

Now, to use the container, we have to inject its dependencies into the module specified above, which is a new views module inside the application directory.

Application code

Cleaner App code

We are now using blueprints for cleaner code and easier dependency injection.

from flask import Flask
from containers.containers import Container
from application.views import meme_routes, root_route


def create_app() -> Flask:

    container = Container()
    app = Flask(__name__)
    app.container = container

    app.register_blueprint(meme_routes)
    app.register_blueprint(root_route)

    return app

Views module and blueprints

Each route is now part of a blueprint (see the Flask documentation on blueprints), and the route function can use the inject decorator from Dependency Injector to make dependencies available inside the function. The dependency is declared as an argument of the function wrapped with inject, and Provide specifies which container provider (here, the one that builds the MemeGenerator) supplies it.
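
Conceptually, Provide[...] is a placeholder default value that gets replaced by the real dependency when the view runs. In the library, this replacement is set up by the wiring step, which patches the functions in the wired modules. The toy decorator below skips wiring entirely and is not how Dependency Injector is implemented; ProvideMarker, simple_inject and make_generator are hypothetical names used only to illustrate the idea. In this toy, an argument passed explicitly takes precedence over the marker.

```python
import functools
import inspect


class ProvideMarker:
    """Placeholder default value that points at a provider function."""

    def __init__(self, provider):
        self.provider = provider


def simple_inject(func):
    """Replace ProvideMarker defaults with provider results at call time."""
    signature = inspect.signature(func)

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Arguments the caller passed explicitly are left untouched
        passed = signature.bind_partial(*args, **kwargs).arguments
        for name, param in signature.parameters.items():
            if name not in passed and isinstance(param.default, ProvideMarker):
                kwargs[name] = param.default.provider()
        return func(*args, **kwargs)

    return wrapper


def make_generator():
    return "real generator"


@simple_inject
def generate_meme(meme_generator=ProvideMarker(make_generator)):
    return f"using {meme_generator}"


print(generate_meme())  # prints: using real generator
```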

from flask import Blueprint, request, send_file
from domain.meme_generator import MemeGenerator
from io import BytesIO
from dependency_injector.wiring import inject, Provide
from containers.containers import Container

meme_routes = Blueprint("memes", __name__)


@meme_routes.route("/generate_meme")
@inject
def generate_meme(
    meme_generator: MemeGenerator = Provide[Container.meme_generator],
):
    args = request.args
    meme = meme_generator.create_meme(
        args.get("top_text", ""), args.get("bottom_text", "")
    )
    bio = BytesIO()
    meme.image.save(bio, "PNG", quality=100)
    bio.seek(0)
    return send_file(bio, mimetype="image/png")


root_route = Blueprint("root", __name__)


@root_route.route("/")
def hello():
    return "hello"

Tests

Let’s create a simple test and apply a dependency override to swap a class with a mock.

Mock Class

Our mock class uses the same interface as the production class. In this test we mock only the gateway, so the mock extends the ImageGateway class, but it has an attribute that selects the type of test and makes it return values accordingly.

For example, if we want to test a successful response, the mock returns a PIL Image instance, but if we test failure states, it raises an error instead.

from PIL import Image
from infrastructure.gateways.gateway import ImageGateway


class ImageGatewayMock(ImageGateway):
    def __init__(self, test_type="SUCCESS"):
        self.test_type = test_type

    def get_image(self, index: int):
        if self.test_type == "SUCCESS":
            return Image.open("assets/local_images/image1.png")
        elif self.test_type == "FAILURE":
            raise Exception("Could not retrieve image")
        raise ValueError(f"Unknown test_type: {self.test_type}")
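
As an alternative to a hand-written mock class, the standard library's unittest.mock can play both roles; this is a swapped-in technique, not the one the article uses. Mock(spec=ImageGateway) only accepts attributes that exist on ImageGateway, return_value covers the SUCCESS case and side_effect covers FAILURE. Either object could be passed to app.container.image_gateway.override(...) the same way as ImageGatewayMock. A string stands in for the PIL image to keep the sketch self-contained.

```python
from unittest.mock import Mock


class ImageGateway:
    """Same interface as the article's ImageGateway."""

    def get_image(self, index: int):
        raise NotImplementedError()


# SUCCESS case: get_image returns a canned value (a string stands in for a PIL image)
success_gateway = Mock(spec=ImageGateway)
success_gateway.get_image.return_value = "fake image"

# FAILURE case: get_image raises, like ImageGatewayMock("FAILURE")
failing_gateway = Mock(spec=ImageGateway)
failing_gateway.get_image.side_effect = Exception("Could not retrieve image")
```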

Testing with dependency override

Dependency Injector lets us easily replace a dependency with another class that has the same interface. Here, we swap the container's image_gateway provider, which normally builds a LocalImageGateway, for an ImageGatewayMock.

This is the essence of DI. Our application code does not depend on LocalImageGateway; it depends on the ImageGateway interface, so it does not care what we pass to it as long as the object has the same methods.

The Dependency Injector library just makes this easier, letting us override the container's providers with whatever we want, as long as the replacement has the same interface so the code doesn’t break.

import pytest
from src.application.app import create_app
from mocks.image_gateway_mock import ImageGatewayMock


@pytest.fixture
def app():
    app = create_app()
    app.config.update(
        {
            "TESTING": True,
        }
    )
    yield app
    app.container.unwire()


@pytest.fixture()
def client(app):
    return app.test_client()


def test_generate_meme(client, app):
    gateway_mock = ImageGatewayMock("SUCCESS")
    with app.container.image_gateway.override(gateway_mock):
        response = client.get("/generate_meme")

    assert response.status_code == 200


def test_generate_meme_gateway_error(client, app):
    gateway_mock = ImageGatewayMock("FAILURE")
    with app.container.image_gateway.override(gateway_mock):
        # With TESTING=True, Flask propagates the exception to the test client,
        # so the request raises instead of returning a 500 response
        with pytest.raises(Exception) as e:
            client.get("/generate_meme?top_text=test")

    assert str(e.value) == "Could not retrieve image"

Conclusion

In this tutorial we created a Flask application that fetches image data from multiple sources and creates a meme from it using user input. We worked in small increments, adding new functionality and ultimately integrating the Dependency Injector library to manage dependencies.

Dependency injection makes the application more scalable, testable and overall more manageable. We easily tested code without accessing external services and without changing core application code.

Applying Dependency Injector to a Flask application is pretty straightforward. We had to prepare our code to support dependency injection, of course, but as long as our dependencies rely on interfaces, we can easily integrate a dependency injection framework like this one.