A look into Deep Java Library!

When you think about building machine learning apps, Java is not the first language that comes to mind, probably not even in the top 3 or 5! But Java has proved time and again that it is capable of modernising itself, and even if it isn't the first choice for the job in many use cases, it offers an option to the 10 million developers who use it.

A few weeks back I started exploring a new Java library called DJL, an open-source, engine-agnostic Java framework for deep learning. In this post we're going to explore some of DJL's capabilities by building a speech recognition application.

Crooked Lake, IN - 7/4/17
Photo by Hunter Harritt / Unsplash

Deep Java Library

DJL was first released in 2019 by Amazon Web Services (AWS), with the aim of offering Java developers a machine learning framework that is simple to use and easy to get started with. It offers multiple Java APIs that simplify training, testing, deploying, analysing, and predicting outputs using deep learning models.

DJL's APIs abstract away the complexity involved in developing deep learning models, making them easy to learn and easy to apply. With the set of pre-trained models bundled in the model zoo, users can immediately start integrating deep learning into their Java applications.

Showtime

As I mentioned earlier, we're building a small speech recognition application. The backend is built with Java 17 and Spring Boot 3.1, and the frontend with React 18.2. The full application code is shared in this repo.


Backend configuration

First of all, we need to add the necessary DJL dependencies. I am using DJL version 0.22.1, the latest release as of this writing. This application needs two DJL dependencies:

  • api (ai.djl:api): the DJL core API.
  • pytorch-engine (ai.djl.pytorch:pytorch-engine): the DJL implementation of the PyTorch engine, which lets us load and run models built with PyTorch.
    <properties>
      <djl.version>0.22.1</djl.version>
    </properties>

    <dependency>
      <groupId>ai.djl</groupId>
      <artifactId>api</artifactId>
      <version>${djl.version}</version>
    </dependency>
    <dependency>
      <groupId>ai.djl.pytorch</groupId>
      <artifactId>pytorch-engine</artifactId>
      <version>${djl.version}</version>
    </dependency>

Next, we need to configure DJL and specify which model we want to use for inference (prediction).
The loadModel method builds a Criteria object that tells DJL how to locate and load the model. In the Criteria we specify:

  • Engine: the engine used to load and run the model, PyTorch in our case.
  • Input/output types: the input type (Audio in our example) and the output type (the transcription, a String).
  • Model URL: where the model archive is located.
  • Model name: the model file to load from that archive (wav2vec2.ptl).
  • Translator: the factory supplying the pre- and post-processing that converts the input Audio into the model's tensors and the model's output back into text.
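The Translator is where the model's raw output becomes text. In our app DJL's SpeechRecognitionTranslatorFactory handles this step, and the sketch below is not its source code. It only illustrates greedy CTC decoding, a simple and common way to turn the per-frame outputs of a CTC-trained model such as wav2vec 2.0 (the model introduced next) into a string: take the most likely token per audio frame, collapse consecutive repeats, and drop the blank token. The class name `GreedyCtcDecoder` is hypothetical, and the assumptions that `"<pad>"` acts as the CTC blank and `"|"` marks word boundaries follow common wav2vec2 character vocabularies rather than anything confirmed for this particular model.

```java
import java.util.List;

/**
 * Hypothetical sketch of greedy CTC decoding (not DJL's SpeechRecognitionTranslator).
 * Assumes a character vocabulary where one token is the CTC blank and one marks word boundaries.
 */
final class GreedyCtcDecoder {

    private final List<String> vocab;   // vocabulary index -> token
    private final String blank;         // CTC blank token (assumed "<pad>")
    private final String wordDelimiter; // word boundary token (assumed "|")

    GreedyCtcDecoder(List<String> vocab, String blank, String wordDelimiter) {
        this.vocab = vocab;
        this.blank = blank;
        this.wordDelimiter = wordDelimiter;
    }

    /** logits[t][v] is the score of vocabulary entry v at audio frame t. */
    String decode(float[][] logits) {
        StringBuilder text = new StringBuilder();
        int previous = -1;
        for (float[] frame : logits) {
            // 1. Pick the most likely token for this frame (argmax).
            int best = 0;
            for (int v = 1; v < frame.length; v++) {
                if (frame[v] > frame[best]) {
                    best = v;
                }
            }
            // 2. Collapse repeats: only emit when the winning token changes.
            if (best != previous) {
                String token = vocab.get(best);
                // 3. Map word delimiters to spaces and drop blanks.
                if (token.equals(wordDelimiter)) {
                    text.append(' ');
                } else if (!token.equals(blank)) {
                    text.append(token);
                }
            }
            previous = best;
        }
        return text.toString().trim();
    }
}
```

Note the role of the blank: a blank between two identical tokens is what allows a genuine double letter to survive the collapsing step.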

We then load the pre-trained model through the [ModelZoo](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/repository/zoo/ModelZoo.html), pointing it directly at a URL for convenience. The model we'll be using is [wav2vec 2.0](https://arxiv.org/abs/2006.11477), a speech model that accepts a float array corresponding to the raw waveform of the speech signal.
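To make "a float array corresponding to the raw waveform" concrete: a typical WAV recording stores 16-bit signed little-endian PCM samples, and a common normalization maps each sample to a float in [-1, 1) by dividing by 32768. In our app, DJL's AudioFactory does the decoding, so this is only a sketch of what that array holds; the `PcmToWaveform` helper is hypothetical, assumes mono 16-bit little-endian input, and makes no claim about DJL's exact scaling.

```java
/**
 * Hypothetical helper: converts 16-bit signed little-endian mono PCM bytes
 * into floats in [-1, 1). Assumes the input really is in that format.
 */
final class PcmToWaveform {

    static float[] toFloats(byte[] pcm) {
        if (pcm.length % 2 != 0) {
            throw new IllegalArgumentException("16-bit PCM needs an even number of bytes");
        }
        float[] samples = new float[pcm.length / 2];
        for (int i = 0; i < samples.length; i++) {
            // Little-endian: low byte first; the high byte carries the sign.
            int lo = pcm[2 * i] & 0xFF;
            int hi = pcm[2 * i + 1]; // sign-extended on purpose
            short value = (short) ((hi << 8) | lo);
            // Scale from [-32768, 32767] to [-1, 1).
            samples[i] = value / 32768f;
        }
        return samples;
    }
}
```

For example, the byte pair `0x00 0x80` is the most negative 16-bit sample and maps to -1.0, while `0xFF 0x7F` maps to just under 1.0.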

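import ai.djl.ModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.audio.Audio;
import ai.djl.modality.audio.translator.SpeechRecognitionTranslatorFactory;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;
import java.io.IOException;
import java.util.function.Supplier;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

@Configuration
public class ModelConfiguration {

  private static final Logger LOG = LoggerFactory.getLogger(ModelConfiguration.class);

  // Loaded once at startup. ZooModel is AutoCloseable, so Spring closes it on shutdown.
  @Bean
  public ZooModel<Audio, String> loadModel() throws IOException, ModelException {
    // Pre-trained wav2vec2 model archive for the PyTorch engine.
    String url = "https://resources.djl.ai/test-models/pytorch/wav2vec2.zip";
    Criteria<Audio, String> criteria =
        Criteria.builder()
            .setTypes(Audio.class, String.class) // input: Audio, output: transcription
            .optModelUrls(url) // where to fetch the model archive
            .optTranslatorFactory(new SpeechRecognitionTranslatorFactory()) // pre/post-processing
            .optModelName("wav2vec2.ptl") // model file inside the archive
            .optEngine("PyTorch") // run the model on the PyTorch engine
            .build();

    return criteria.loadModel();
  }

  // Hands out a new Predictor per call, since DJL Predictors are not guaranteed to be thread-safe.
  @Bean
  public Supplier<Predictor<Audio, String>> predictorProvider(ZooModel<Audio, String> model) {
    return model::newPredictor;
  }
}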

That's pretty much all the configuration we need to start using our model. The service class simply calls the predictor for inference.

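  // jakarta.annotation.Resource (Spring Boot 3); injects the predictorProvider bean by name.
  @Resource
  private Supplier<Predictor<Audio, String>> predictorProvider;

  public String predict(InputStream stream) throws IOException, ModelException, TranslateException {
    // Decode the uploaded audio into DJL's Audio object.
    Audio audio = AudioFactory.newInstance().fromInputStream(stream);

    // Use a fresh Predictor for this request; try-with-resources closes it afterwards.
    try (var predictor = predictorProvider.get()) {
      return predictor.predict(audio);
    }
  }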
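Before handing an upload to the predictor, it can help to check its format, assuming the model expects 16 kHz mono audio (wav2vec 2.0 models are usually trained on 16 kHz speech, while browsers often record at 44.1 or 48 kHz). The sketch below uses only the JDK's javax.sound.sampled API; the `WavFormatCheck` class and its 16 kHz constant are hypothetical additions, not part of the article's code or of DJL.

```java
import java.io.BufferedInputStream;
import java.io.IOException;
import java.io.InputStream;
import javax.sound.sampled.AudioFormat;
import javax.sound.sampled.AudioInputStream;
import javax.sound.sampled.AudioSystem;
import javax.sound.sampled.UnsupportedAudioFileException;

/** Hypothetical pre-check for uploaded WAV files (assumes a 16 kHz mono model). */
final class WavFormatCheck {

    static final float EXPECTED_SAMPLE_RATE = 16_000f; // assumption, see lead-in

    /** Reads only the audio format; note that closing the AudioInputStream also closes the upload. */
    static AudioFormat readFormat(InputStream upload) throws IOException, UnsupportedAudioFileException {
        // AudioSystem needs mark/reset support to detect the file type.
        try (AudioInputStream audio = AudioSystem.getAudioInputStream(new BufferedInputStream(upload))) {
            return audio.getFormat();
        }
    }

    /** True if the recording is mono at the expected sample rate. */
    static boolean isModelReady(AudioFormat format) {
        return format.getChannels() == 1 && format.getSampleRate() == EXPECTED_SAMPLE_RATE;
    }
}
```

Because both this check and AudioFactory consume the stream, one option is to read the upload into a byte array first and wrap it in a fresh ByteArrayInputStream for each step.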

The rest is just standard Spring Boot configuration.

Frontend Configuration

The frontend makes use of the excellent react-audio-analyser library, which records audio in the browser and converts it to WAV format. The rest is straightforward: the app makes a REST call to the transcription endpoint and shows the result in the browser.
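To exercise the backend without the browser, the frontend's REST call can be approximated with the JDK's built-in HttpClient. This is a sketch under loud assumptions: the `TranscriptionClient` class, the endpoint URI, and sending raw WAV bytes with an `audio/wav` content type are all hypothetical. Check the repo's controller for the real path, and if it expects a multipart form upload (as browser FormData uploads usually are), build the body accordingly.

```java
import java.io.IOException;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;

/** Hypothetical client for the transcription endpoint (path and body format are assumptions). */
final class TranscriptionClient {

    /** Builds a POST request carrying the raw WAV bytes as the body. */
    static HttpRequest buildRequest(URI endpoint, byte[] wavBytes) {
        return HttpRequest.newBuilder(endpoint)
                .header("Content-Type", "audio/wav")
                .POST(HttpRequest.BodyPublishers.ofByteArray(wavBytes))
                .build();
    }

    /** Sends the request and returns the response body as the transcription. */
    static String transcribe(HttpClient client, URI endpoint, byte[] wavBytes)
            throws IOException, InterruptedException {
        HttpResponse<String> response =
                client.send(buildRequest(endpoint, wavBytes), HttpResponse.BodyHandlers.ofString());
        if (response.statusCode() != 200) {
            throw new IOException("Unexpected HTTP status " + response.statusCode());
        }
        return response.body();
    }
}
```

Usage would look like `TranscriptionClient.transcribe(HttpClient.newHttpClient(), URI.create("http://localhost:8080/transcribe"), Files.readAllBytes(path))`, with the URI again being a placeholder.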