How you can improve Neural Networks with DANN

Using Domain-Adversarial Neural Networks for better performance

Devansh
14 min read · May 22, 2024

Reminder- I am flying to San Francisco between the 24th and 28th of this month. I will also be attending the RevOps AF event in San Diego from the 28th to the 30th of May. Come say hi if you’re in the neighborhood.

Executive Highlights (tl;dr of the article)

Thanks to their versatility, Neural Networks are a staple in most modern Machine Learning pipelines. Their ability to work with unstructured data is a blessing, since it lets us partially (partially being important here) replace domain insight (expensive and hard to attain) with computational scale (cheaper and easier to attain).

However, it turns out that blindly stacking layers and throwing money at problems doesn’t work as well as some people would like you to believe. There are several underlying issues with the training process that scale does not fix, chief amongst them distribution shift and generalization. Neural Networks break when we give them inputs in unexpected ways/formats. For example, despite the impressive gains in LLM capabilities, this has been one of the most difficult problems for them to address-

Do Language Models really understand Language

There are several ways to improve generalization, such as implementing sparsity and/or regularization to reduce overfitting and applying data augmentation to mithridatize your models and stop them from being such delicate princesses.

In this article, we will look at two super interesting papers, “Domain-Adversarial Neural Networks” (the intro) and “Domain-Adversarial Training of Neural Networks” (more details), to explore the technique of Domain Adversarial Neural Networks (DANN) and how it can improve generalization. To understand the implications fully, we will first cover important background information-

  • What is Distribution Shift: Distribution shift, also known as dataset shift, is a phenomenon in machine learning where the statistical distribution of the data changes between the training and deployment environments. Its best-known form, covariate shift, is a change in the distribution of the input data (features or covariates). This can lead to a significant degradation in the performance of a model that has been trained on a specific data distribution when it encounters data from a different distribution.
Some of you are probably old enough to remember when the world went from black and white to having color. What was that like? Would love to hear from you old people. Image Source
  • Possible Sources of Distribution Shift: Distribution shift can arise from various sources. These include sample selection bias, where training data doesn’t reflect the real-world distribution; non-stationary environments, where data changes over time; domain adaptation challenges, where models trained on one domain struggle in another; data collection and labeling issues; adversarial attacks; and concept drift, where the relationship between the features and the target variable changes (this is something I dealt with when trying to build AI that sells loans to people).
  • Mitigating distribution shift: Good data + adversarial augmentation + constant monitoring works wonders. Some people try to use complex augmentation schemes, but this is usually pointless. Keep it simple and throw in some random noise. It works really well and also helps you avoid unnecessary complexity in your solutions.
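Here is how small the "throw in some random noise" advice can be in practice. This is a minimal NumPy sketch that assumes real-valued input features. The function name and the default `sigma` are illustrative choices, not something prescribed above; you would tune `sigma` to your feature scale.

```python
import numpy as np

def add_gaussian_noise(batch, sigma=0.05, rng=None):
    """Return a copy of `batch` with zero-mean Gaussian noise added.

    `sigma` is a hypothetical default; scale it to your features.
    """
    rng = np.random.default_rng() if rng is None else rng
    return batch + rng.normal(loc=0.0, scale=sigma, size=batch.shape)

rng = np.random.default_rng(0)
x = np.ones((4, 3))                      # a toy batch: 4 examples, 3 features
x_aug = add_gaussian_noise(x, sigma=0.1, rng=rng)
print(x_aug.shape)  # (4, 3)
```

Apply it on the fly during training, so that each epoch sees a fresh perturbation of the same examples.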

With that covered, we will then talk about DANNs and why I like them so much-

  • What is DANN- DANNs are based on a simple observation- we know that a Neural Network (or any AI Model) has generalized well if it performs well on a related dataset that it has NOT been trained on. So train a model on reviews on Amazon (the source dataset), and see how well it does on reviews on Reddit (the target dataset). We want AI Models that perform like Jude for Real Vardrid and not like Sancho for United.
  • How DANNs Attain Peak Generalization- DANNs aim for domain invariance by learning domain-invariant features: representations that are useful for the main task while carrying as little information as possible about whether an input came from the source or the target domain. This is done through Domain Adversarial Training.
  • Domain Adversarial Training: Domain-Adversarial Training (DAT) involves training a neural network with two competing objectives: to accurately perform the main task (e.g., classification), and to confuse a domain classifier that tries to distinguish between source and target domain data. This adversarial process encourages the network to learn features that are useful for the task but independent of the domain. The result is improved performance on the target domain even when target labels are scarce or missing entirely. The key component enabling this is the gradient reversal layer. It acts as an identity function during the forward pass but multiplies gradients by a negative constant during backpropagation, creating a minimax game between the feature extractor and the domain classifier.
Figure 1: The proposed architecture includes a deep feature extractor (green) and a deep label predictor (blue), which together form a standard feed-forward architecture. Unsupervised domain adaptation is achieved by adding a domain classifier (red) connected to the feature extractor via a gradient reversal layer that multiplies the gradient by a certain negative constant during the backpropagation-based training. Otherwise, the training proceeds standardly and minimizes the label prediction loss (for source examples) and the domain classification loss (for all samples). Gradient reversal ensures that the feature distributions over the two domains are made similar (as indistinguishable as possible for the domain classifier), thus resulting in the domain-invariant features

I really like the idea behind DANN since it is intuitive, functions as a good proxy for quantifying something otherwise hard to measure (how different two domains look to a model), and can be molded to fit various domains/models. Let’s spend the rest of this article exploring these ideas in more detail. The papers themselves are pretty old, but I think the idea is very relevant to building modern systems.
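One standard way to turn "can a classifier tell the domains apart?" into a number is the proxy A-distance, from the domain adaptation theory (Ben-David et al.) that DANN builds on. The recipe is:

1. Train a source-vs-target classifier.
2. Measure its held-out error ε.
3. Report 2(1 - 2ε).

Values near 2 mean the domains are easy to separate, and values near 0 mean they look alike. Below is a NumPy toy version. The nearest-centroid classifier is a hypothetical stand-in for whatever domain classifier you would actually train.

```python
import numpy as np

def proxy_a_distance(domain_true, domain_pred):
    """Proxy A-distance: 2 * (1 - 2 * err), where err is the held-out
    error of a classifier trained to separate source (0) from target (1)."""
    err = np.mean(np.asarray(domain_true) != np.asarray(domain_pred))
    return 2.0 * (1.0 - 2.0 * err)

def nearest_centroid_domain_classifier(src_train, tgt_train, x):
    # Hypothetical stand-in classifier: assign each point to the closer domain mean.
    c_src, c_tgt = src_train.mean(axis=0), tgt_train.mean(axis=0)
    d_src = np.linalg.norm(x - c_src, axis=1)
    d_tgt = np.linalg.norm(x - c_tgt, axis=1)
    return (d_tgt < d_src).astype(int)  # 1 = predicted target

def pad_between(src, tgt):
    # Train on the first half of each domain, evaluate on the second half.
    half = len(src) // 2
    x_test = np.vstack([src[half:], tgt[half:]])
    y_test = np.r_[np.zeros(len(src) - half), np.ones(len(tgt) - half)]
    y_pred = nearest_centroid_domain_classifier(src[:half], tgt[:half], x_test)
    return proxy_a_distance(y_test, y_pred)

rng = np.random.default_rng(0)
src = rng.normal(0.0, 1.0, size=(400, 2))
tgt_far = rng.normal(3.0, 1.0, size=(400, 2))   # strongly shifted target
tgt_near = rng.normal(0.0, 1.0, size=(400, 2))  # no shift at all

print(pad_between(src, tgt_far))   # close to 2: domains are easy to tell apart
print(pad_between(src, tgt_near))  # close to 0: domains are indistinguishable
```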

Our experiments on a sentiment analysis classification benchmark, where the target domain data available at training time is unlabeled, show that our neural network for domain adaption algorithm has better performance than either a standard neural network or an SVM, even if trained on input features extracted with the state-of-the-art marginalized stacked denoising autoencoders

I provide various consulting and advisory services. If you‘d like to explore how we can work together, reach out to me through any of my socials over here or reply to this email.

The effect of adaptation on the distribution of the extracted features (best viewed in color). The figure shows t-SNE (van der Maaten, 2013) visualizations of the CNN’s activations (a) in case when no adaptation was performed and (b) in case when our adaptation procedure was incorporated into training. Blue points correspond to the source domain examples, while red ones correspond to the target domain. In all cases, the adaptation in our method makes the two distributions of features much closer.

Join 150K+ tech leaders and get insights on the most important ideas in AI straight to your inbox through my free newsletter- AI Made Simple

Understanding Distribution Shift

I’m going to skip re-explaining what distribution shift is, since we already did it in the highlights and I believe in your intelligence. If you want extra help, take a look at a more traditional example of what a distribution shift looks like in AI-

Source. Models trained on feature 1 will clearly need to account for these changes.
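The figure isn't reproduced here, but you can still put a number on "feature 1 moved". A simple option is the two-sample Kolmogorov-Smirnov statistic: the largest vertical gap between the two empirical CDFs. Below is a pure-Python sketch with made-up data. In practice you would likely use a statistics library rather than hand-rolling it.

```python
def ks_statistic(sample_a, sample_b):
    """Two-sample Kolmogorov-Smirnov statistic: the largest gap between
    the two empirical CDFs (0 = identical samples, 1 = no overlap)."""
    a, b = sorted(sample_a), sorted(sample_b)
    n, m = len(a), len(b)
    i = j = 0
    gap = 0
    while i < n and j < m:
        x = min(a[i], b[j])
        # Advance both samples past every value <= x, then compare CDFs at x.
        while i < n and a[i] <= x:
            i += 1
        while j < m and b[j] <= x:
            j += 1
        gap = max(gap, abs(i * m - j * n))  # integer arithmetic avoids float noise
    return gap / (n * m)

training_feature = list(range(100))              # hypothetical "feature 1" at training time
deployed_feature = [v + 50 for v in range(100)]  # same shape, shifted by 50
print(ks_statistic(training_feature, deployed_feature))  # 0.5
print(ks_statistic(training_feature, training_feature))  # 0.0
```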

People often underestimate how often, and in how many ways, shift can occur. Let’s work through some examples before covering how to fix it. Some of these will overlap with each other, but I’m keeping them separate to demonstrate the different perspectives one can take when addressing these challenges.

Sample Selection Bias

Sample selection bias occurs when the training data doesn’t accurately represent the real-world data distribution. This mismatch can lead to models that perform well on the training data but generalize poorly to real-world scenarios.

  • Example 1: Ball Tracking and Unrepresented Edge Cases:
    A while back, there was a story of a SOTA Vision model that was meant to track the football during a match. However, it ended up tracking a referee's bald head. My guess is that the dataset didn’t have any top-angle pictures of bald people, so the model confused the smooth, round head with the ball.
Image Source. I’m sure the X-Risk crew will work overtime to convince you that this is actually proof that AI has decided that Bald People should be removed from the gene pool and is asking us to use their heads as balls.
  • Example 2: Medical Diagnosis and Hospital-Specific Biases:
    A disease prediction model trained on data from a specific hospital might not generalize well to patients from different hospitals or demographics. This is because medical practices, patient demographics, and even disease prevalence can vary significantly across different healthcare settings.

Non-Stationary Environments

In many real-world applications, data is not static. As the world changes, models trained on historical data might become less accurate as time progresses and the data distribution evolves.

  • Example 1: Financial Markets and Evolving Trends:
    A stock prediction model trained on historical stock prices might become less accurate over time due to factors like evolving market dynamics, new regulations, or unforeseen global events.
  • Example 2: Language Models and Shifting Language:
    Language Models trained on social media data need to adapt to constantly evolving language use, slang (I recently learned about the word brainrot), and emerging topics. Different generations might even completely change what words mean (the word terrible has drifted a long way from its original meaning) or use different styles of communication (one generation might be inspired by shows like Yes Minister and Silicon Valley to be more satirical). Not to mention that culture is always shifting because of interactions with other groups, which can introduce new language patterns and change rhetorical devices.

Domain Adaptation Challenges

Domain adaptation addresses the challenge of applying models trained on one domain (the source domain) to a different but related domain (the target domain). Even if the task remains the same, differences in data characteristics can significantly impact performance.

  • Example 1: Medical Image Segmentation — Different Modalities:
    A model trained to segment tumors in brain MRI scans might not perform well when applied to CT scans of the brain. Although both modalities provide images of the brain, differences in imaging techniques and image characteristics require domain adaptation.
  • Example 2: Natural Language Processing — Reviews vs. Tweets:
    A sentiment analysis model trained on product reviews might struggle with social media posts like tweets. The shorter length, informal language, and use of slang in tweets create a different domain compared to the more structured language of product reviews.

Adversarial Attacks: Deliberately Fooling the Model

Adversarial attacks involve intentionally crafting malicious input data to exploit vulnerabilities in machine learning models. These attacks aim to cause the model to make incorrect predictions, potentially leading to severe consequences.

  • Example 1: Image Recognition — Adversarial Noise:
    Adding subtle, carefully crafted noise to an image can be imperceptible to humans but cause an image recognition system to misclassify it. For example, an attacker might make minor modifications to a stop sign image, causing a self-driving car’s vision system to misinterpret it.

The results show that 67.97% of the natural images in Kaggle CIFAR-10 test dataset and 16.04% of the ImageNet (ILSVRC 2012) test images can be perturbed to at least one target class by modifying just one pixel with 74.03% and 22.91% confidence on average. We also show the same vulnerability on the original CIFAR-10 dataset. Thus, the proposed attack explores a different take on adversarial machine learning in an extreme limited scenario, showing that current DNNs are also vulnerable to such low dimension attacks.

-One Pixel Attack
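The one-pixel attack finds its pixel with differential evolution. A simpler gradient-based attack that exposes the same fragility is the Fast Gradient Sign Method (FGSM). It nudges every input feature by ±ε in whichever direction increases the loss. The sketch below applies FGSM (not the one-pixel method) to a hypothetical logistic-regression model with hand-picked weights. For this model, the gradient of the loss with respect to the input has the closed form (p - y)·w.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def fgsm_linear(x, y, w, b, eps):
    """FGSM against logistic regression p = sigmoid(w.x + b).

    For binary cross-entropy, d(loss)/dx = (p - y) * w, so each feature
    moves by eps in the sign of that gradient.
    """
    p = sigmoid(x @ w + b)
    grad_x = (p - y) * w
    return x + eps * np.sign(grad_x)

w = np.array([1.0, -2.0, 0.5])   # hypothetical trained weights
b = 0.0
x = np.array([0.5, -0.5, 1.0])   # logit = 2.0, confidently class 1
y = 1.0
x_adv = fgsm_linear(x, y, w, b, eps=0.8)
print(sigmoid(x @ w + b), sigmoid(x_adv @ w + b))  # confidence drops below 0.5
```

The perturbation is bounded by ε per feature. In high-dimensional inputs such as images, many tiny changes add up, which is why the noise can stay imperceptible while still flipping the prediction.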

  • Example 2: Spam Detection — Evolving Spam Techniques:
    Spammers constantly evolve their techniques to bypass spam filters. A spam detection model needs to adapt to these evolving tactics, such as using subtle misspellings or inserting hidden text, to remain effective. We are working on something similar for exploring the security of Multi-Modal AI and how adversarial hackers would be able to bypass any security checks there.
  • Example 3: Language Model Attacks:
    Similar to Vision Models, we can attack language models either to jailbreak them (undoing the protections implemented through alignment or guardrails) or to have them spit out their training data. This usually involves carefully crafted prompts, prompt injection being a common family. One of the most notable examples was DeepMind’s Poem Attack, which got ChatGPT to leak training data by asking it to repeat the word "poem" forever.
Extracting data from ChatGPT

Concept Drift

Concept drift occurs when the underlying relationship between input features and the target variable changes over time. If we’re not careful, this can make old models completely useless.

  • Example 1: Credit Scoring — Evolving Economic Factors:
    Factors that predicted high creditworthiness a few years ago might not hold true today due to changing economic conditions or consumer behavior. A credit scoring model needs to adapt to these evolving relationships to ensure fair and accurate assessments.
  • Example 2: Recommender Systems — Shifting User Preferences:
    Recommender systems rely on understanding user preferences, which can change over time. A model that recommends products based on past purchase history needs to adapt to evolving tastes and preferences to remain relevant.
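Concept drift does not necessarily show up in the inputs, because the same applicant can simply become a worse bet. A simple check is to track accuracy on freshly labeled data, window by window, and raise an alarm when it falls well below a baseline. The pure-Python sketch below uses a hypothetical toy stream. The window size and the 0.1 tolerance are arbitrary choices.

```python
def windowed_accuracy(preds, labels, window):
    """Accuracy over consecutive, non-overlapping windows of `window` examples."""
    return [
        sum(p == y for p, y in zip(preds[i:i + window], labels[i:i + window])) / window
        for i in range(0, len(preds) - window + 1, window)
    ]

def drift_alarms(accuracies, baseline, tolerance=0.1):
    """Indices of windows whose accuracy fell more than `tolerance` below `baseline`."""
    return [i for i, acc in enumerate(accuracies) if baseline - acc > tolerance]

# Hypothetical stream: the model learned "approve if x > 0", which was right at first.
xs = list(range(-20, 20)) * 4
old_rule = [x > 0 for x in xs[:80]]
new_rule = [x > 10 for x in xs[80:]]  # the relationship changes halfway through
labels = old_rule + new_rule
preds = [x > 0 for x in xs]           # the model itself never changed

accs = windowed_accuracy(preds, labels, window=40)
print(accs, drift_alarms(accs, baseline=accs[0]))
```

The inputs `xs` have exactly the same distribution before and after the change. Only the labels moved, which is what separates concept drift from covariate shift.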

Fixing distribution shift revolves around the usual suspects-

Data Quality and Quantity: Ensure training data encompasses the variety expected in real-world deployment. On top of this, expand the training dataset by applying transformations (e.g., rotations, crops, color shifts) to existing data, improving robustness to variations. Some teams try to get too clever with Data Augmentation, implementing complex schemes to address the limitations of their dataset. You probably don’t need that. Before you decide to get too clever, consider the following statement from TrivialAugment-

Second, TA teaches us to never overlook the simplest solutions. There are a lot of complicated methods to automatically find augmentation policies, but the simplest method was so-far overlooked, even though it performs comparably or better
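As I understand TrivialAugment, its policy really is that simple. For each image, it picks one augmentation uniformly at random and applies it at a random strength. Here is a NumPy sketch of that idea. The four operations and their strength mappings are hypothetical toy stand-ins for the paper's image operations, and the strength is drawn from a continuous range for simplicity.

```python
import numpy as np

# Hypothetical toy ops; each takes an image with values in [0, 1] and a strength s in [0, 1].
def identity(img, s):
    return img

def brightness(img, s):
    return np.clip(img + (s - 0.5) * 0.4, 0.0, 1.0)

def contrast(img, s):
    mean = img.mean()
    return np.clip(mean + (img - mean) * (0.5 + s), 0.0, 1.0)

def horizontal_flip(img, s):
    return img[:, ::-1]

OPS = [identity, brightness, contrast, horizontal_flip]

def trivial_augment(img, rng):
    """Apply ONE op, chosen uniformly at random, at a uniformly random strength."""
    op = OPS[rng.integers(len(OPS))]
    strength = rng.uniform(0.0, 1.0)
    return op(img, strength)

rng = np.random.default_rng(0)
image = rng.uniform(0.0, 1.0, size=(8, 8))  # a toy grayscale "image"
augmented = [trivial_augment(image, rng) for _ in range(5)]
print([a.shape for a in augmented])
```

There is no search over policies and there are no tuned magnitudes, which is exactly the point the quote is making.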

Lastly, you should remember that there is a lot of good in exploring different feature extractions to build richer representations of your data. That’s one of the best ways to improve model performance.

Ensemble Methods: Combine predictions from multiple models trained on different subsets of data or with different architectures to improve resilience to individual model biases.
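One simple way to combine models is to average their predicted class probabilities and take the argmax; majority voting is a common alternative. Below is a minimal NumPy sketch with made-up outputs from three hypothetical models. In this example, one model is confidently wrong and gets outvoted.

```python
import numpy as np

def ensemble_predict(prob_matrices):
    """Average per-class probabilities from several models, then take the argmax."""
    avg = np.mean(np.stack(prob_matrices), axis=0)
    return avg, avg.argmax(axis=1)

# Hypothetical outputs of three models on two examples (columns = classes 0 and 1).
model_a = np.array([[0.9, 0.1], [0.4, 0.6]])
model_b = np.array([[0.6, 0.4], [0.3, 0.7]])
model_c = np.array([[0.2, 0.8], [0.45, 0.55]])  # model C is wrong on the first example

avg, preds = ensemble_predict([model_a, model_b, model_c])
print(preds)  # [0 1]
```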

Continual Learning and Model Updates: Implement systems that continuously learn and adapt to new data, ensuring models remain relevant as data distributions evolve. Retraining gets expensive, which is another reason why simple models are king.

Robust Evaluation and Monitoring: Evaluate models on diverse datasets and monitor performance over time to detect and address potential distribution shifts.
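For monitoring the inputs themselves, one widely used metric (especially in credit scoring) is the Population Stability Index (PSI). The recipe is:

1. Bin a feature.
2. Compare each bin's share of traffic at training time with its share now.
3. Sum (actual - expected) · ln(actual / expected) over the bins.

Below is a pure-Python sketch with hypothetical histograms. The `floor` used for empty bins is an arbitrary choice.

```python
import math

def population_stability_index(expected_counts, actual_counts, floor=1e-6):
    """PSI between two histograms over the same bins.

    expected_counts: bin counts at training time.
    actual_counts:   bin counts in recent production traffic.
    """
    e_total, a_total = sum(expected_counts), sum(actual_counts)
    psi = 0.0
    for e, a in zip(expected_counts, actual_counts):
        e_pct = max(e / e_total, floor)  # floor avoids log(0) for empty bins
        a_pct = max(a / a_total, floor)
        psi += (a_pct - e_pct) * math.log(a_pct / e_pct)
    return psi

training_bins = [250, 250, 250, 250]  # hypothetical feature histogram at training time
recent_bins = [100, 200, 300, 400]    # hypothetical histogram of recent traffic
print(round(population_stability_index(training_bins, recent_bins), 3))  # 0.228
```

A common industry rule of thumb reads PSI below 0.1 as stable and above 0.25 as a significant shift. Treat those cut-offs as conventions rather than laws.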

With that exploration of distribution shift out of the way, let’s look into DANN and how it helps us handle domain shift better.

We demonstrate the success of our approach for two distinct classification problems (document sentiment analysis and image classification), where state-of-the-art domain adaptation performance on standard benchmarks is achieved. We also validate the approach for descriptor learning task in the context of person re-identification application.

A cultist’s introduction to Domain Adversarial Training

The key to DANN’s domain adaptation is its training process, and the secret that makes this training possible lies in its architecture.

We show that this adaptation behaviour can be achieved in almost any feed-forward model by augmenting it with few standard layers and a new gradient reversal layer. The resulting augmented architecture can be trained using standard backpropagation and stochastic gradient descent, and can thus be implemented with little effort using any of the deep learning packages.

Let’s break that down in more depth. To do so, we’re going to pull up the architecture diagram again.

More AI researchers should use pretty coloring in their diagrams

The core idea is to train a model to perform well on a specific task while simultaneously making it difficult to distinguish between the source domain and the target domain.

Here are the main components-

Feature Extractor: The first part of the model learns to extract features from the input data. These features should be relevant for the main task (e.g., classification) while being as similar as possible across both domains.

Label Predictor: This component focuses on the primary task, utilizing the extracted features to make predictions (e.g., classify images). It is typically trained solely on labeled data from the source domain.

Domain Classifier: This part acts as an adversary to the feature extractor. Its goal is to determine whether a given feature representation originates from the source or target domain.

Adversarial Training: The training process involves a continuous back-and-forth between the feature extractor and the domain classifier:

  • The feature extractor tries to learn representations that fool the domain classifier, making it difficult to distinguish between the domains.
  • The domain classifier, in turn, tries to improve its ability to differentiate between the domains based on the provided features.

More mathematically, these equations are what determine the training-
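The equation image did not survive here. Reconstructed from "Domain-Adversarial Training of Neural Networks" as I read it, these are the per-example SGD updates, with numbering following that paper. θ_f, θ_y and θ_d are the parameters of the feature extractor, label predictor and domain classifier. L_y^i and L_d^i are the label and domain losses on example i (the label loss only exists for labeled source examples). μ is the learning rate, and λ weighs the domain loss.

```latex
\begin{aligned}
\theta_f &\leftarrow \theta_f - \mu \left( \frac{\partial L_y^i}{\partial \theta_f} - \lambda \frac{\partial L_d^i}{\partial \theta_f} \right) && (13) \\
\theta_y &\leftarrow \theta_y - \mu \frac{\partial L_y^i}{\partial \theta_y} && (14) \\
\theta_d &\leftarrow \theta_d - \mu \lambda \frac{\partial L_d^i}{\partial \theta_d} && (15)
\end{aligned}
```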

This is mostly similar to standard deep learning. Here θ represents the parameters and μ is the learning rate. The important bit to note is in equation 13. First, we have λ, which weighs our domain loss. The rest is business as usual, with one twist: “The only difference is that in (13), the gradients from the class and domain predictors are subtracted, instead of being summed (the difference is important, as otherwise SGD would try to make features dissimilar across domains in order to minimize the domain classification loss).”
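In the paper, λ is not a fixed constant. It is ramped up over training as λ_p = 2 / (1 + exp(-γ·p)) - 1. Here p is the fraction of training completed (0 to 1), and the paper reports γ = 10 in its experiments. The point is to keep the domain classifier's noisy early signal from disrupting the feature extractor. Below is a tiny sketch of that schedule. The function name is mine.

```python
import math

def dann_lambda(progress, gamma=10.0):
    """Domain-loss weight lambda_p = 2 / (1 + exp(-gamma * p)) - 1.

    `progress` is the fraction of training completed, in [0, 1].
    The weight starts at 0 and approaches 1 as training proceeds.
    """
    return 2.0 / (1.0 + math.exp(-gamma * progress)) - 1.0

for p in (0.0, 0.1, 0.25, 0.5, 1.0):
    print(p, round(dann_lambda(p), 3))
```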

Pretty neat, huh? Let’s now cover the final important player.

Gradient Reversal Layer (GRL): This special layer is a crucial element in DAT. It sits between the feature extractor and the domain classifier. During the forward pass (prediction), it acts as the identity function, passing the features through unchanged. However, during the backward pass (learning), it multiplies the incoming gradients by a negative constant (-λ), reversing their direction before they reach the feature extractor.
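The layer's behaviour fits in a few lines. Below is a framework-free NumPy sketch with explicit `forward` and `backward` methods. The class name is hypothetical, `lam` plays the role of λ, and the gradient values are made up.

```python
import numpy as np

class GradientReversal:
    """Gradient reversal layer: identity on the forward pass,
    gradient multiplied by -lam on the backward pass."""

    def __init__(self, lam=1.0):
        self.lam = lam

    def forward(self, features):
        # Forward pass: the domain classifier sees the features unchanged.
        return features

    def backward(self, grad_from_domain_classifier):
        # Backward pass: flip (and scale) the gradient before it reaches the feature extractor.
        return -self.lam * grad_from_domain_classifier


grl = GradientReversal(lam=0.5)
h = np.array([[1.0, 2.0], [3.0, 4.0]])   # toy features
print(grl.forward(h) is h)               # True: nothing happens going forward

g = np.array([[0.2, -0.1], [0.4, 0.3]])  # made-up gradient of the domain loss w.r.t. h
print(grl.backward(g))                   # each entry multiplied by -0.5
```

In an autograd framework, you would typically write this as a custom autograd function. Its forward returns the input unchanged, and its backward returns the negated, scaled gradient.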

This works because it exploits the dynamic between the domain classifier, the main task, and the feature extractor. The feature extractor receives two conflicting signals during backpropagation:

  • From the label predictor: Update your parameters to improve task performance (e.g., classification accuracy).
  • From the domain classifier (via the GRL): Update your parameters in the opposite direction to make it harder for the domain classifier to distinguish between domains.

The end result is the extraction of features that are still useful for the main task but increasingly less informative about the domain origin. In other words, it learns domain-invariant features.
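Here is one full-batch training step for a deliberately tiny, hypothetical DANN in NumPy. It has three parts:

- a linear feature extractor;
- a logistic label predictor, trained on labeled source data only;
- a logistic domain classifier, trained on source (label 0) and target (label 1) data.

Gradients are written out by hand. The feature-extractor update takes the label gradient minus λ times the domain gradient, which is exactly the combined signal a gradient reversal layer delivers. This is a sketch for intuition under those simplifying assumptions, not the papers' implementation. Real models would use non-linear layers and an autograd framework.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def dann_step(params, xs, ys, xt, mu=0.1, lam=0.5):
    """One DANN update. params: W_f (d x k), w_y (k,), b_y, w_d (k,), b_d.
    xs, ys: labeled source batch; xt: unlabeled target batch."""
    W_f, w_y, b_y, w_d, b_d = (params[k] for k in ("W_f", "w_y", "b_y", "w_d", "b_d"))

    # Label loss (binary cross-entropy) on source examples only.
    hs = xs @ W_f
    py = sigmoid(hs @ w_y + b_y)
    ry = (py - ys) / len(xs)                 # dL_y / dlogit
    g_wy, g_by = hs.T @ ry, ry.sum()
    g_Wf_y = xs.T @ np.outer(ry, w_y)        # dL_y / dW_f

    # Domain loss on all examples: source = 0, target = 1.
    xa = np.vstack([xs, xt])
    da = np.r_[np.zeros(len(xs)), np.ones(len(xt))]
    ha = xa @ W_f
    pd = sigmoid(ha @ w_d + b_d)
    rd = (pd - da) / len(xa)                 # dL_d / dlogit
    g_wd, g_bd = ha.T @ rd, rd.sum()
    g_Wf_d = xa.T @ np.outer(rd, w_d)        # dL_d / dW_f

    eps = 1e-12
    loss_y = -np.mean(ys * np.log(py + eps) + (1 - ys) * np.log(1 - py + eps))
    loss_d = -np.mean(da * np.log(pd + eps) + (1 - da) * np.log(1 - pd + eps))

    new = {
        # Feature extractor: label gradient MINUS lambda * domain gradient (the GRL effect).
        "W_f": W_f - mu * (g_Wf_y - lam * g_Wf_d),
        # Label predictor: plain descent on the label loss; never sees the domain loss.
        "w_y": w_y - mu * g_wy, "b_y": b_y - mu * g_by,
        # Domain classifier: plain descent on the lambda-weighted domain loss.
        "w_d": w_d - mu * lam * g_wd, "b_d": b_d - mu * lam * g_bd,
    }
    return new, loss_y, loss_d

rng = np.random.default_rng(0)
xs = rng.normal(size=(200, 2))
ys = (xs[:, 0] > 0).astype(float)                     # label depends on feature 1 only
xt = rng.normal(size=(200, 2)) + np.array([0.0, 3.0])  # target shifted along feature 2
init = {"W_f": rng.normal(scale=0.1, size=(2, 2)), "w_y": rng.normal(scale=0.1, size=2),
        "b_y": 0.0, "w_d": rng.normal(scale=0.1, size=2), "b_d": 0.0}

params = init
for _ in range(200):
    params, loss_y, loss_d = dann_step(params, xs, ys, xt)
print(round(float(loss_y), 3), round(float(loss_d), 3))
```

One easy sanity check starts from the same parameters and compares λ = 0 with λ = 1. The label predictor's update is identical in both cases, while the feature extractor's update differs.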

To a degree, this reminds me of the excellent MIT paper “Adversarial Examples Are Not Bugs, They Are Features”. The authors presented us with a simple theory: there are two types of informative features in image classifiers, fragile features (the paper calls them non-robust) and robust features. Both can be used to classify images with high accuracy, but fragile features break under perturbation while robust features do not.

Adversarial Attacks attack the fragile features.

The GRL seems to play a similar role to robust feature extraction: it helps us pull out the features that are most robust to a change of domain.

Putting all these together gives us the DANN. It’s elegant, intuitive, and performant. A lot of new AI papers get published every day, but I think it can be helpful to look through older publications to redevelop ideas from first principles. Hopefully, this piece gets you to explore this side of research more. What are some fields of AI that you feel aren’t respected enough? Would love to know.

Accuracy evaluation of different DA approaches on the standard Office (Saenko et al., 2010) data set. All methods (except SA) are evaluated in the “fully-transductive” protocol (some results are reproduced from Long and Wang, 2015). Our method (last row) outperforms competitors setting the new state-of-the-art.

If you liked this article and wish to share it, please refer to the following guidelines.

That is it for this piece. I appreciate your time. As always, if you’re interested in working with me or checking out my other work, my links will be at the end of this email/post. And if you found value in this write-up, I would appreciate you sharing it with more people. It is word-of-mouth referrals like yours that help me grow.

I put a lot of effort into creating work that is informative, useful, and independent from undue influence. If you’d like to support my writing, please consider becoming a paid subscriber to this newsletter. Doing so helps me put more effort into writing/research, reach more people, and supports my crippling chocolate milk addiction. Help me democratize the most important ideas in AI Research and Engineering to over 100K readers weekly.


PS- We follow a “pay what you can” model, which allows you to support within your means. Check out this post for more details and to find a plan that works for you.

I regularly share mini-updates on what I read on the Microblogging sites X(https://twitter.com/Machine01776819), Threads(https://www.threads.net/@iseethings404), and TikTok(https://www.tiktok.com/@devansh_ai_made_simple)- so follow me there if you’re interested in keeping up with my learnings.

Reach out to me

Use the links below to check out my other content, learn more about tutoring, reach out to me about projects, or just to say hi.

Small Snippets about Tech, AI and Machine Learning over here

AI Newsletter- https://artificialintelligencemadesimple.substack.com/

My grandma’s favorite Tech Newsletter- https://codinginterviewsmadesimple.substack.com/

Check out my other articles on Medium: https://rb.gy/zn1aiu

My YouTube: https://rb.gy/88iwdd

Reach out to me on LinkedIn. Let’s connect: https://rb.gy/m5ok2y

My Instagram: https://rb.gy/gmvuy9

My Twitter: https://twitter.com/Machine01776819


Devansh

Writing about AI, Math, the Tech Industry and whatever else interests me. Join my cult to gain inner peace and to support my crippling chocolate milk addiction