Article

KRA: K-Nearest Neighbor Retrieval Augmented Model for Text Classification

1 School of Computer Science (National Pilot Software Engineering School), Beijing University of Posts and Telecommunications, Beijing 100876, China
2 School of International Chinese Language Education, Beijing Normal University, Beijing 100875, China
3 School of Computer Science & Technology, Beijing Institute of Technology, Beijing 100081, China
* Author to whom correspondence should be addressed.
Electronics 2024, 13(16), 3237; https://doi.org/10.3390/electronics13163237
Submission received: 18 July 2024 / Revised: 12 August 2024 / Accepted: 14 August 2024 / Published: 15 August 2024
(This article belongs to the Special Issue Emerging Theory and Applications in Natural Language Processing)

Abstract

Text classification is a fundamental task in natural language processing (NLP). Deep-learning-based text classification methods usually have two stages, training and inference; however, the training dataset is used only in the training stage. To make full use of the training dataset in the inference stage and thereby improve model performance, we propose a k-nearest neighbors retrieval augmented method (KRA) for deep-learning-based text classification models. KRA first constructs a datastore that holds the embeddings of the training samples during the training stage. During the inference stage, the model retrieves the top k nearest neighbors of the test text from this datastore. We then expand the retrieved neighbors using text augmentation methods, including traditional augmentation methods and a large language model (LLM)-based method. Next, the method weights the augmented neighbors by their distances from the target text and incorporates their labels into the final prediction accordingly. We evaluate KRA on six benchmark datasets using four commonly used deep learning models: CNN, LSTM, BERT, and RoBERTa. The results demonstrate that KRA significantly improves the classification performance of these models, with an average accuracy improvement of 0.3% for BERT and up to 0.4% for RoBERTa. These improvements highlight the effectiveness and generalizability of KRA across different models and datasets, making it a valuable enhancement for a wide range of text classification tasks.

1. Introduction

Text classification is a fundamental task in the field of natural language processing (NLP) [1] wherein a given text is categorized into one or more predefined categories. The applications of text classification are diverse and include spam detection [2], sentiment analysis [3], relationship extraction [4], etc. Deep-learning-based methods are commonly used for text classification. Unlike traditional machine learning models, deep learning models are free from feature engineering, which can be time-consuming and domain-expertise-intensive. Deep-learning-based methods like recurrent neural networks (RNNs) [5] and convolutional neural networks (CNNs) [6] have been successfully employed. Additionally, transformer-based [7] pre-trained language models such as BERT [8], RoBERTa [9], XLNet [10], and GPT-3 [11] have recently demonstrated remarkable performance in various NLP tasks. Text classification methods based on deep learning usually include two stages: training and inference. However, existing methods typically use the training dataset only during the training phase, meaning that its semantic information is not fully explored or utilized during the inference stage.
To better leverage the semantic information in the training dataset and boost model effectiveness, a viable approach is to use retrieval augmented methods [12], which have recently attracted considerable attention in the computational linguistics community. Retrieval augmented techniques have found extensive application in various natural language processing tasks, including dialogue response generation [13], machine translation [14], language modeling [15], and named entity recognition [16]. However, these approaches are effective only when implemented with a large retrieval set [17]. When the retrieval set is small, text augmentation can be used to expand it.
In this paper, we introduce KRA, a novel enhancement component for deep learning text classification models. First, KRA constructs a datastore during the training stage to store the embeddings of the training set. During the inference stage, the model then retrieves the top k nearest neighbors of the test text from this datastore. To expand the retrieval set, we generate additional texts from the retrieved neighbors via a text augmentation module, using both traditional text augmentation methods and a large language model (LLM)-based method. Finally, the augmented neighbors are weighted by their distances from the target text, and their labels are incorporated into the final results accordingly. KRA thus enables the model to leverage information from the training set and improve classification performance. Experiments conducted on six benchmarks demonstrate the efficacy of KRA across four commonly used deep learning models: CNN, LSTM, BERT, and RoBERTa. We summarize our contributions as follows:
  • We present a new approach, KRA, for enhancing deep learning models used in text classification. This method leverages training dataset information and is applicable to a broad range of deep learning structures without requiring additional training.
  • We use both traditional and LLM-based text augmentation methods to enhance the performance of the KNN classifier.
  • Experiments conducted on six benchmarks demonstrate the effectiveness of KRA across four commonly used deep learning models: CNN, LSTM, BERT, and RoBERTa.
The rest of the paper is organized as follows: Section 2 provides a comprehensive review of related work in the field of text classification and data augmentation. Section 3 details the proposed KRA method, including the k-nearest neighbors classifier and text augmentation techniques. Section 4 describes the experimental setup, including datasets, evaluation metrics, and implementation details, and presents the experimental results and discusses the findings. Finally, Section 5 concludes the paper and outlines potential directions for future research.

2. Related Work

2.1. Text Classification

In the field of natural language processing (NLP), deep learning models have been widely utilized for text classification tasks due to their effectiveness in handling large amounts of textual data and their ability to automatically extract meaningful patterns and features. Various deep structures have been developed and applied, each offering unique advantages for processing and understanding text.
Among the earliest deep learning models used for text classification are recurrent neural networks (RNNs). RNNs, including their more sophisticated variant, long short-term memory (LSTM) networks, are particularly suited for sequential data due to their ability to maintain information over time. LSTMs, as discussed in the works of Liu et al. [5,18], address the vanishing gradient problem encountered in traditional RNNs and can capture long-range dependencies, making them effective for understanding context in sequences of text.
In addition to RNNs, convolutional neural networks (CNNs) have also demonstrated strong performance in text classification tasks. CNNs, which are typically used for image processing, have been adapted for NLP tasks, as highlighted by Kim [6] and Shen et al. [19]. By applying convolutional filters across text embeddings, CNNs can detect local patterns such as key phrases or important n-grams, which contributes to their robustness in text classification.
A more recent and transformative advancement in NLP has been the development of transformer-based models. These models, such as Bidirectional Encoder Representations from Transformers (BERT) [8] and RoBERTa [9], have significantly advanced the state-of-the-art in numerous NLP tasks, including text classification. Transformers utilize self-attention mechanisms to weigh the importance of different words in a sentence relative to each other, allowing them to capture complex dependencies and nuanced meanings across the entire text. Unlike earlier models, transformers process entire sentences simultaneously rather than sequentially, which enables them to be highly parallelizable and efficient in handling large datasets.
The widespread adoption of deep learning methods in text classification is primarily due to their ability to automatically learn high-level features from raw textual data, eliminating the need for extensive manual feature engineering. Traditional machine learning models often rely on handcrafted features and domain expertise to achieve good performance, which can be both time-consuming and resource-intensive. Deep learning models, on the other hand, leverage large datasets and hierarchical feature learning to discern intricate patterns and representations that might be overlooked by manual feature engineering.
Moreover, the advent of pre-trained language models, such as BERT and RoBERTa, has revolutionized the way text classification tasks are approached. These models are trained on vast amounts of text from diverse sources and can be fine-tuned for specific tasks with relatively small amounts of labeled data. This transfer learning approach not only enhances performance but also reduces the computational and time costs associated with training models from scratch.
The utilization of deep learning models in text classification tasks is driven by their ability to efficiently and effectively learn from raw text data, obviating the need for meticulous feature engineering. The evolution from RNNs and CNNs to transformer-based models like BERT and RoBERTa continues to push the boundaries of what is possible in the realm of natural language processing, offering ever-improving tools for understanding and classifying text.

2.2. K-Nearest Neighbors in NLP

The k-nearest neighbors (KNN) algorithm is a well-established and intuitive approach that relies on similarities between data points to make predictions. It functions by identifying the k closest data points in the feature space to a query point and then making predictions based on the properties of these neighbors. In recent years, studies have leveraged KNN in innovative ways to enhance various natural language processing (NLP) tasks, demonstrating that its utility extends beyond traditional usage.
For instance, Khandelwal [15] proposes a novel framework that integrates nearest neighbors to augment the predictions made by a language model. This approach involves retrieving similar instances from a training corpus and using them to inform the model’s output, thereby enhancing its generalization capabilities and robustness. This method exemplifies how KNN can be combined with neural language models to improve predictive accuracy by incorporating contextual insights from similar past instances.
Similarly, Kassner [20] explores the application of KNN in a question-answering task. By generating additional predictions through KNN, Kassner demonstrates that the performance of BERT, a powerful transformer-based model, can be significantly improved. This involves using KNN to retrieve possible candidate answers from a large knowledge base, thereby providing the model with supplementary information that aids with producing more accurate responses.
In another study, Li [21] utilizes a KNN-based classifier to fine-tune BERT, a pre-trained language model, specifically for enhancing its representation in text classification tasks. This method involves leveraging the nearest neighbors to refine the model’s understanding of different classes, which results in improved classification performance. By integrating KNN with BERT, the study highlights the advantages of combining distance-based learning with deep learning architectures to achieve better task-specific representations.
Borgeaud [17] employs KNN for a different purpose: to enrich text generation tasks. In this work, KNN is used to retrieve nearest neighbors from an enormous retrieval set comprising trillions of tokens. The retrieved neighbors provide valuable contextual clues that the generation model can use to produce more coherent and contextually appropriate text. This massive-scale application of KNN illustrates its potential in handling extensive datasets and improving the diversity and quality of generated content.
Notably, these methods all utilize KNN to augment pre-trained language models rather than relying solely on the KNN classifier as a decision-maker. By doing so, they leverage the strengths of both KNN and deep learning models. KNN’s ability to provide context-sensitive insights is combined with the powerful feature extraction and generalization capabilities of pre-trained language models. This hybrid approach maximizes the performance and flexibility of NLP applications, making it a valuable strategy for enhancing various tasks, from language modeling and question answering to text classification and generation.
The integration of KNN with advanced pre-trained language models has opened up new avenues for improving NLP tasks. These approaches highlight how traditional algorithms like KNN can be adapted and extended to complement modern deep learning techniques, leading to more sophisticated and effective solutions in the rapidly evolving field of natural language processing.

2.3. Text Augmentation

Text augmentation is a widely used technique in natural language processing (NLP) that aims to generate new training data by applying various transformations to existing data. The primary objective of text augmentation is to enhance the diversity and quantity of training data, which can lead to improved performance of NLP models. This technique has gained significant popularity in NLP due to the scarcity of labeled data, which is both expensive and time-consuming to create.
Text augmentation methods help address the challenge of limited labeled data by creating synthetic examples that mimic the properties of real data. By diversifying the training set, these methods enable models to generalize better and become more robust to different input variations.
Prior research has introduced several methodologies for augmenting data in the field of natural language processing. A notable investigation involves the creation of novel data by translating English sentences into French and subsequently translating them back into English [22]. This technique, known as back-translation, leverages the differences in language structure and vocabulary to produce varied yet semantically equivalent text. The process introduces subtle variations that can enhance the model’s ability to understand and generate diverse expressions of the same idea.
Additionally, other scholars employ data noising as a means of smoothing [23]. This method involves intentionally introducing noise into the data, such as by randomly swapping or deleting words, to teach the model to be resilient to imperfect inputs. This approach helps models handle real-world scenarios where data may be noisy or contain errors.
Predictive language models have also been harnessed to substitute synonyms for text augmentation [24]. By using models like BERT or GPT, researchers can identify and replace words with their contextually appropriate synonyms, thereby generating variations of the original text that retain the same meaning. This method leverages the rich semantic understanding embedded in pre-trained language models to ensure that replacements are contextually relevant.
Another effective method [25] involves operations such as synonym replacement, permutation, and random deletion on the text to enhance its semantic richness. Known as easy data augmentation (EDA), this set of techniques aims to create diverse and meaningful variations by modifying the text in simple yet effective ways. By randomizing certain elements, EDA helps to prevent models from overfitting to specific patterns in the data.
The emergence of large language models (LLMs) has brought new opportunities for text augmentation. Text augmentation methods based on large language models, such as AugGPT [26], have achieved promising results. These methods utilize the powerful generative capabilities of LLMs to create high-quality synthetic data. For instance, AugGPT can generate contextually coherent and semantically rich variations of input text, providing an additional layer of diversity and complexity to the augmented data.
In recent research, a novel large language model (LLM)-based text augmentation approach [27] has been proposed to enhance personality detection from social media posts. This method addresses the scarcity of ground-truth personality traits and the limitations of current models by leveraging LLM to generate semantic, sentiment, and linguistic augmentations. Through contrastive learning, this approach effectively captures psycho-linguistic information, leading to superior performance in personality detection compared to state-of-the-art methods.
In general, the goal of text augmentation is to produce sensible and diverse new samples that maintain semantic consistency. Effective augmentation strategies ensure that the generated data remain relevant and valuable for training, enhancing the model’s performance across various NLP tasks. In conclusion, text augmentation is a crucial technique in NLP for improving model performance, especially in scenarios with limited labeled data. By employing approaches such as back-translation, data noising, synonym substitution, and leveraging large language models, researchers and practitioners can create diverse and robust training datasets. These advancements in text augmentation continue to push the boundaries of what is possible in the realm of natural language processing, enabling the development of more sophisticated and capable models.

3. Proposed Method

To better utilize the semantic information of the training dataset and improve the performance of the model, we propose KRA, which consists of a classic k-nearest neighbors classifier and a text augmentation module. The k-nearest neighbors classifier leverages semantic information from the training dataset to aid in classification decision-making. The text augmentation module is designed to improve the quality of the training dataset, enhance the generalization ability of the k-nearest neighbors classifier, and ensure the effectiveness of KRA when the training dataset is not large enough. Figure 1 shows the overall structure of the model.

3.1. Datastore

Let $f(\cdot)$ be a function that maps the target text to a fixed-length vector representation computed by a pre-trained language model. For example, let $c$ be the current text sequence; in BERT, $f(c)$ maps the text sequence $c$ to a 768-dimensional vector (the default hidden size of BERT). Then, for the $i$-th example $(c_i, l_i) \in D$ in the training set $D$, we define a key–value pair $(k_i, v_i)$. The key $k_i$ is the vector representation of the text sequence $c_i$, i.e., $f(c_i)$; the value $v_i$ is the label $l_i$ of the text sequence $c_i$. The datastore $(K, V)$ thus stores the vector representations and labels of all text sequences in the training dataset:
$(K, V) = \{ (f(c_i), l_i) \mid (c_i, l_i) \in D \}$
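As a concrete illustration, the following minimal sketch builds such a datastore with a Hugging Face BERT encoder. The [CLS] pooling, the variable names, and the toy training pairs are our own assumptions for the sketch, not details prescribed by the paper.

```python
# A minimal datastore-construction sketch (Section 3.1). Assumptions: a Hugging
# Face BERT encoder and [CLS] pooling for f(.); names and data are illustrative.
import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
encoder = AutoModel.from_pretrained("bert-base-uncased")
encoder.eval()

def f(text: str) -> torch.Tensor:
    """Map a text sequence to a fixed-length (768-dim) vector representation."""
    inputs = tokenizer(text, truncation=True, max_length=512, return_tensors="pt")
    with torch.no_grad():
        hidden = encoder(**inputs).last_hidden_state   # (1, seq_len, 768)
    return hidden[0, 0]                                # [CLS] vector, shape (768,)

# D is the training set: (text c_i, label l_i) pairs; these two are toy examples.
D = [("a gripping and moving film", 1), ("a tedious, lifeless mess", 0)]
K = torch.stack([f(c_i) for c_i, _ in D])   # keys: one embedding per training text
V = torch.tensor([l_i for _, l_i in D])     # values: the corresponding labels
```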

3.2. Text Augmentation

To enhance the diversity of retrieved neighbors, we employ both traditional and LLM-based text augmentation methods to expand the retrieved neighbors. Table 1 illustrates examples of different augmentation methods.

3.2.1. LLM-Based Method

We utilize GPT-3.5 for text rewriting, with our prompt outlined in Table 2.
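For illustration, a hedged sketch of this rewriting step is given below, assuming the OpenAI Python client; the model name, the line-based parsing, and the helper name are our assumptions rather than details from the paper.

```python
# A sketch of the LLM-based augmentation using the Table 2 prompt. Assumptions:
# the OpenAI Python client (v1) and "gpt-3.5-turbo"; parsing is illustrative.
from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment

def llm_augment(text: str) -> list[str]:
    """Request five rewrites of `text`, following the prompt in Table 2."""
    response = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[
            {"role": "system", "content": "You are a helpful assistant that rewrites "
                "texts and makes sentences smooth, meaningful, and concise."},
            {"role": "user", "content": "I will give you a sample, please rewrite it, "
                "then give me 5 rewrite answers."},
            {"role": "assistant", "content": "Sure, please provide the sample text "
                "that you'd like me to rewrite."},
            {"role": "user", "content": text},
        ],
    )
    # Assumes the model returns one numbered rewrite per line ("1. ...", "2. ...");
    # real output may need sturdier parsing.
    lines = response.choices[0].message.content.splitlines()
    return [ln.split(".", 1)[1].strip()
            for ln in lines if ln[:1].isdigit() and "." in ln]
```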

3.2.2. Traditional Methods

We use five traditional methods to implement text augmentation: synonym replacement (SR), random insertion (RI), random swap (RS), random deletion (RD), and back translation (BT).
The SR technique expands the lexical variety of the dataset by replacing some words in the original text with their synonyms, introducing slight semantic changes. The RI technique randomly inserts additional words or phrases into the original text, simulating noise or interference in real-world scenarios. The RS technique randomly swaps words or phrases in the original text, perturbing the syntax and semantics to a certain extent. The RD technique deletes words or phrases from the original text with a certain probability, simulating situations in which part of the information is missing. The BT technique is a machine-translation-based method that generates semantically perturbed samples by translating the original text into another language and then translating it back; it introduces larger semantic differences and enriches the diversity of the dataset.
By using these methods in combination, we can effectively enhance the training dataset, improving the richness and diversity of the neighbor data.
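To make the EDA-style operations concrete, here is a toy sketch of four of the five methods; the one-word synonym table is a stand-in for a real WordNet lookup, and back translation is omitted because it requires a translation model.

```python
# Toy sketches of SR, RI, RS, and RD (Section 3.2.2). The SYNONYMS table is a
# stand-in for a WordNet lookup; BT is omitted (it needs a translation model).
import random

SYNONYMS = {"famous": ["well-known", "renowned"], "player": ["athlete"]}

def synonym_replacement(words, n=1):
    out = words[:]
    candidates = [i for i, w in enumerate(out) if w in SYNONYMS]
    for i in random.sample(candidates, min(n, len(candidates))):
        out[i] = random.choice(SYNONYMS[out[i]])   # swap in a synonym
    return out

def random_insertion(words, n=1):
    out = words[:]
    for _ in range(n):                             # insert noise words
        out.insert(random.randrange(len(out) + 1), random.choice(list(SYNONYMS)))
    return out

def random_swap(words, n=1):
    out = words[:]
    for _ in range(n):                             # perturb word order
        i, j = random.sample(range(len(out)), 2)
        out[i], out[j] = out[j], out[i]
    return out

def random_deletion(words, p=0.1):
    kept = [w for w in words if random.random() > p]
    return kept or [random.choice(words)]          # never return an empty sentence

sentence = "Luka Doncic is a famous Slovenian basketball player".split()
print(" ".join(synonym_replacement(sentence)))
```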

3.3. Inference

During the inference process, given a test text $x$, the model computes the vector $f(x)$ and uses it to search the datastore for the $k$ nearest neighbors under the Euclidean distance. Following the approach outlined in [15], we use the Euclidean distance rather than the cosine similarity. The original texts $C$ corresponding to these $k$ nearest neighbors are then fed into the text augmentation module, yielding augmented neighbor vectors $f(C)$. The weights $W$ of the vectors are calculated from their distances as follows:
$W = \mathrm{Softmax}(-|D|)$
Here, $D$ contains the Euclidean distances between the augmented neighbor vectors $f(C)$ and $f(x)$; the negative sign gives closer neighbors larger weights. The probability distribution over the neighbor labels, denoted $\mathrm{KNN}(x)$, is then calculated as
$\mathrm{KNN}(x) = W \cdot L$
where $L$ is the set of labels of the augmented neighbors. Finally, combining this with the probability distribution $\mathrm{Softmax}(f(x))$ given by the model itself, we obtain the final probability distribution $s$:
$s = \lambda \, \mathrm{KNN}(x) + (1 - \lambda) \, \mathrm{Softmax}(f(x))$
The ratio $\lambda$ balances the model's own probability distribution against the distribution obtained from the neighbor vectors.
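Putting the pieces together, a minimal sketch of this inference step follows. It assumes the K/V datastore and f(.) from the datastore sketch in Section 3.1, treats the neighbor labels as one-hot vectors so that $W \cdot L$ yields a distribution, and uses a stand-in classification head `clf`; for brevity, the augmentation of the retrieved neighbors is indicated but not performed.

```python
# A minimal inference sketch (Section 3.3), assuming K, V, and f from the
# datastore sketch. `clf` is a stand-in head, untrained, for illustration only.
import torch
import torch.nn.functional as F

clf = torch.nn.Linear(768, 2)  # stand-in classification head

def kra_predict(x: str, k: int = 32, lam: float = 0.6, num_classes: int = 2):
    q = f(x)                                              # embed the test text
    dists = torch.cdist(q.unsqueeze(0), K).squeeze(0)     # Euclidean distances
    k = min(k, len(K))                                    # guard for tiny datastores
    nn_idx = dists.topk(k, largest=False).indices         # k nearest neighbors
    # (Here the retrieved neighbor texts would be augmented, embedded with f,
    #  and appended to the neighbor set before weighting.)
    w = F.softmax(-dists[nn_idx], dim=0)                  # closer => larger weight
    labels = F.one_hot(V[nn_idx], num_classes).float()    # one-hot neighbor labels
    knn_dist = w @ labels                                 # KNN(x)
    model_dist = F.softmax(clf(q), dim=-1)                # Softmax(f(x))
    return lam * knn_dist + (1 - lam) * model_dist        # s
```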

4. Experiments

4.1. Experiment Setup

4.1.1. Datasets

We use six benchmark datasets to evaluate the effectiveness of our proposed method. The datasets are described below.
The MR dataset [28] is a well-known benchmark in the field of sentiment analysis. It comprises a collection of movie reviews, each represented by a single sentence, making it a useful corpus for training and evaluating models on short, sentiment-laden texts. The dataset is balanced, containing 5331 positive reviews and 5331 negative reviews, providing an equal representation of both sentiment classes. One of the key features of the MR dataset is its use of 10-fold cross-validation by random splitting, which is a common practice for testing and validating models on this dataset. This method involves partitioning the dataset into ten equal subsets and using nine subsets for training and one subset for testing in each iteration. This process is repeated ten times, with each subset being used as the test set once. The results are then averaged to provide a robust evaluation metric, minimizing the impact of any particular data split and ensuring that the model’s performance is not overly dependent on a specific subset of the data.
The MPQA [29] dataset is a widely used opinion dataset in the field of natural language processing. It contains annotated opinions and sentiments, making it a valuable resource for tasks such as sentiment analysis and opinion mining. The dataset features two class labels, positive and negative, which are commonly used for opinion polarity detection sub-tasks. Specifically, the MPQA dataset comprises 10,606 sentences extracted from news articles sourced from various news outlets, providing a diverse and representative sample of real-world opinions. These sentences are annotated to indicate their sentiment and comprise 3311 positive texts and 7293 negative texts.
The IMDB dataset [30] is a well-known sentiment analysis dataset that contains 50,000 comments extracted from the IMDB website, an extensive database of information and reviews about movies. This dataset is split evenly with 25,000 training samples and 25,000 testing samples. Each comment is labeled as one of two categories: positive or negative. The labels are determined based on the ratings: comments with ratings of 7 or higher are classified as positive, comments rated 4 or lower are classified as negative, and comments with a rating of 5 or 6 are excluded from the dataset to maintain a clear distinction between the two sentiment categories. This dataset is widely utilized in sentiment analysis research for evaluating model performance on movie review classification tasks.
The AG’s News dataset is a news categorization dataset created by Zhang [31]. It encompasses news articles organized into four distinct categories: business, sci/tech, sports, and world. Each category consists of a substantial number of samples: specifically, 120,000 training samples and 7600 testing samples, resulting in a cumulative total of 127,600 samples. To manage computational resources and to maintain a balanced sub-sample, we randomly selected a subset of 50,000 samples for our experiments. This dataset is commonly used to assess the performance of text classification algorithms on diverse news topics.
The DBpedia Ontology Classification dataset is a comprehensive text classification dataset provided by Zhang [31]. This dataset involves 14 diverse categories, which include people, places, organizations, and animals, among others. Each category is represented by approximately 10,000 article titles and summaries sourced from Wikipedia. For the purpose of our experiments, a random subset of 50,000 samples is selected. This dataset is particularly useful for evaluating the capability of classification models to distinguish between a wide range of semantic categories based on short textual descriptions.
The SST-2 (Stanford Sentiment Treebank) dataset [32] is specifically designed for sentiment analysis and focuses on binary classification. It involves categorizing text into positive and negative sentiment categories. The dataset comprises 10,662 samples, which are divided into 8544 training samples and 2118 testing samples. Each text sample within the SST-2 dataset originates from movie reviews and is labeled with a corresponding sentiment tag indicating either positive or negative sentiment. The annotation process for each sentence was conducted by approximately 10 human annotators, which ensures a high degree of accuracy and consistency in the sentiment labels. This dataset is used extensively for benchmarking sentiment analysis models due to its reliable and well-annotated nature.

4.1.2. Baseline Models

The proposed KRA methodology serves as a versatile augmentation technique that can be integrated with a broad spectrum of deep learning models, especially those employed in text classification tasks. To evaluate the effectiveness of our approach, we select several widely recognized models as baselines.
CNN and LSTM: We employ the convolutional neural network (CNN) model as described by Kim [6] and the long short-term memory (LSTM) network detailed by Liu [5]. Both models are initialized using pre-trained GloVe word embeddings. For these models, the final hidden states are extracted and stored as the neighbor vectors.
BERT and RoBERTa: We utilize the bert-base-uncased model from Devlin et al. [8] and the RoBERTa-base model from Liu et al. [9]. For both BERT and RoBERTa, the outputs from the last layer are saved as neighbor vectors. The hidden states from these final layers are used as text representations, which are subsequently fed into a fully connected layer to perform text classification.
Our selection of these models ensures a comprehensive evaluation of the KRA method across different types of architecture, thereby demonstrating its generalizability and robustness.

4.1.3. Settings

For the BERT and RoBERTa models, we save their outputs locally and utilize a fully connected layer to reduce the dimensionality of the embeddings. Text samples exceeding 510 tokens are truncated to the first 510 tokens to accommodate the input limitations for both the BERT and RoBERTa models. For the LSTM model, we set both the embedding size and hidden size to 64. The CNN model employs three filters of sizes 3, 10, and 25, with each convolutional block containing 100 filters. Both the LSTM and CNN models leverage pre-trained GloVe embeddings with an embedding size of 100.
We conduct the experiments on a server with the following configuration: an Intel Xeon Platinum 8260 CPU, two NVIDIA RTX 2080 Ti GPUs with 11 GB of VRAM each, and 120 GB of RAM. The operating system used is Ubuntu 18.04, with CUDA version 11.3 and PyTorch version 1.11.
To augment the training dataset, we employ the easy data augmentation (EDA) techniques described by Wei and Zou [25] as well as GPT-3.5. The augmented training dataset is pre-saved for later retrieval. Each text sample in the training dataset undergoes six augmentation methods (LLM, SR, RI, RS, RD, and BT), yielding five samples generated by the LLM and five generated by the traditional methods (one per method).
The parameter K, the number of neighbors to retrieve, is a critical hyperparameter. Based on preliminary experiments, we found that K = 32 consistently provides a good balance between performance and search time, making it an effective choice for the subsequent experiments. For the interpolation parameter λ, a value of 0.6 is used consistently in our experiments, although more granular tuning could potentially yield further gains.
Model parameters are optimized using the Adam algorithm [33] with an initial learning rate of 0.001 and a batch size of 128. This setup ensures robust training performance across the different models and augmentation techniques used in our experiments.

4.2. Experimental Results

4.2.1. Test Accuracy

Table 3 shows the overall test results of the four baseline models and of the models with our method added. The results show that our method performs well on the BERT, RoBERTa, CNN, and LSTM models and outperforms the baselines on most of the datasets. The improvement is most pronounced for BERT and RoBERTa, where our method outperforms the baseline on all six datasets. Our method also performs well on the CNN and LSTM models, outperforming the baselines on almost all datasets; the only exception is the IMDB dataset, where CNN and LSTM with KRA fall slightly below their baselines, although KRA still exceeds the baseline there for the two pre-trained models. In terms of datasets, our method improves the performance of all four models on the MR, MPQA, AG's News, SST-2, and DBPedia datasets. The overall results verify the effectiveness and generality of the method.
Our method achieves good results on both pre-trained models, BERT and RoBERTa, with an average accuracy improvement of 0.3% across all datasets. In particular, it achieves the best results on RoBERTa, with an average improvement of about 0.4%. This is because BERT and RoBERTa, as pre-trained models, encode text better than CNNs and LSTM networks, which allows the KNN retriever to match neighbors that are more similar to the target text. For the CNN and LSTM models on the IMDB dataset, our method performs below the baseline. This may be due to the relatively weak ability of CNNs and LSTM networks to understand long texts: the IMDB dataset has a much greater average text length than the other datasets, making it challenging to find vectors similar to the target text in the training dataset. On the other five datasets, however, our method significantly improves the performance of the models.
Our method demonstrates superior performance on datasets with shorter average text lengths, such as MR, MPQA, AG’s News, SST-2, and DBPedia, compared to the IMDB dataset. The results, as shown in Table 3, indicate that KRA achieves significant improvements over the baseline models on these five datasets. This trend suggests that text augmentation methods might be more effective at enhancing shorter texts, potentially due to the ease of generating high-quality augmented samples for such texts. Shorter texts may have simpler structures and more focused content, making it easier for augmentation techniques to produce relevant and coherent variations that can aid in improving model performance. On the other hand, the IMDB dataset, characterized by its significantly longer average text length, poses a greater challenge for text augmentation methods. The complexity and length of the texts in IMDB make it harder to generate high-quality augmented samples that are both contextually coherent and semantically relevant. Consequently, the benefits of our KRA method are less pronounced on the IMDB dataset, especially for models with relatively weaker text encoding capabilities such as CNNs and LSTM networks. Nevertheless, KRA still manages to enhance the performance of the pre-trained models BERT and RoBERTa on this dataset, albeit to a lesser extent compared to the shorter-text datasets.

4.2.2. Ablation Study

Table 4 shows the results obtained using the KNN classifier with and without text augmentation. We test the baseline models and the models using only the KNN classifier on six datasets; both are inferior to the models with KRA added. From the table, we can see that the KNN method alone achieves some improvement on certain datasets, such as IMDB and AG's News with BERT as the baseline. However, in some cases it performs poorly and can even hurt the results, possibly because the KNN algorithm requires a high-quality retrieval set. In contrast, the results obtained with text augmentation are generally better than those obtained with the KNN algorithm alone, indicating the importance and effectiveness of the text augmentation module. This shows that both the KNN module and the text augmentation module in our approach are indispensable.
Table 5 shows the classification accuracy on the six datasets using BERT, comparing the performance of using only traditional methods, only the LLM-based method, and a combination of both. The combination of traditional and LLM-based methods outperforms the individual methods on nearly all datasets, achieving clear improvements over using traditional methods or LLM-based methods alone. This improvement can be attributed to the complementary nature of the two augmentation methods in terms of information gain. Traditional text augmentation methods (such as synonym replacement and random swapping) introduce diverse textual variations, increasing the richness of the data, while LLMs tend to generate longer and more complex sentences, providing contextually richer and semantically more varied text. By combining both, KRA can more effectively extract and utilize the semantic information in the training data, thereby enhancing the model's performance during the inference stage.

4.2.3. Hyperparameter Analysis

K and λ are two important hyperparameters in our method, and adjusting them can affect the final performance of the method.
K is the number of neighbors; in general, a larger K leads to better performance in the KNN algorithm. However, our experiments show that the size of K does not have a dramatic impact on the effectiveness of the model. We conducted experiments on the AG's News dataset using LSTM with different K values; the results are shown in Figure 2. The best performance is achieved at K = 256, and performance drops when K > 256. K = 32 is a relatively economical choice because it achieves good performance while reducing retrieval cost.
The parameter λ adjusts the influence of the KNN classifier's results relative to the linear classifier's results; it is important because we combine the two classifiers. We conducted experiments on the AG's News dataset using LSTM with different λ values, and Figure 2 shows their impact on the results. When λ = 0, the linear classifier makes all the decisions, and when λ = 1, the KNN classifier makes all the decisions. Neither extreme performs as well as λ = 0.6, indicating that the combination of the two classifiers is complementary. Although the model performs best at λ = 0.6, this does not imply that the KNN classifier's results matter more than the linear classifier's; it is only through their combination that the best performance is achieved.

4.2.4. The Impact of Model Fine-Tuning

In the previous experiments, to achieve optimal results, the pre-trained BERT and RoBERTa models were fine-tuned on the downstream datasets. We also investigated whether fine-tuning affects the results. Therefore, we conducted experiments using BERT on the AG's News and DBPedia datasets, with the results shown in Table 6. From the table, we observe that our method still improves the performance of frozen BERT by an average of 0.2% on the two datasets. This demonstrates that our approach does not require additional fine-tuning or training for downstream tasks when applied to pre-trained models like BERT. This is because pre-trained models such as BERT and RoBERTa have already been trained on extensive corpora and possess strong text encoding capabilities; this superior encoding enables our method to retrieve more similar neighbors, thereby effectively improving classification performance.
To further illustrate the consistent effectiveness of our method during the fine-tuning process, we plotted the accuracy of both the baseline BERT and BERT with KRA on the AG’s News dataset across different training steps, as shown in Figure 3. From the figure, it is evident that our method consistently outperforms the baseline BERT model at every training step. Notably, the accuracy of BERT + KRA starts higher and maintains a clear margin over the baseline BERT throughout the entire training process. Additionally, BERT+KRA reaches optimal performance around the 5200th training step and continues to show slight improvements thereafter, achieving higher final accuracy compared to the baseline BERT. This consistent improvement across all training steps indicates that our method is robust and effective in enhancing the performance of pre-trained models without the need for extensive fine-tuning. The KRA method leverages the powerful text encoding capabilities of pre-trained models to retrieve semantically similar neighbors, thereby augmenting the training data in a meaningful way and improving classification performance. The experimental results demonstrate that the KRA method provides continuous and reliable improvements over baseline models, making it a valuable enhancement for a wide range of text classification tasks.

4.2.5. Computational Cost

We measured the inference time using BERT on several datasets in our experimental environment. From the results shown in Table 7, we can see that the overhead introduced by our method is minimal for datasets with relatively smaller training sets, such as MR and MPQA. However, for larger datasets like AG’s News, our method does lead to some additional time overhead.
We make several efforts to reduce this overhead. First, we pre-augment the datasets and store the augmented data, which eliminates the need for real-time text augmentation. Second, we use FAISS's IndexFlatL2 index to accelerate vector retrieval. In terms of real-time inference, our method typically adds only a few milliseconds of overhead. Specifically, on the larger AG's News dataset, the average inference time per sample with our method is about 26 ms, compared to 18 ms for the conventional method, an additional delay of approximately 8 ms.
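For reference, a minimal sketch of this retrieval setup is shown below; it assumes the key matrix K and the function f from the datastore sketch in Section 3.1, and the query text is made up.

```python
# A sketch of the FAISS retrieval mentioned above: an exact-L2 (IndexFlatL2)
# index over the datastore keys. K is the key matrix; the query is illustrative.
import faiss
import numpy as np

xb = K.numpy().astype("float32")            # datastore keys, shape (N, 768)
index = faiss.IndexFlatL2(xb.shape[1])      # exact Euclidean search
index.add(xb)

xq = f("stocks rallied after the earnings report").numpy()[None, :].astype("float32")
distances, neighbor_ids = index.search(xq, 32)   # k = 32, as in Section 4.1.3
```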
Overall, while our approach introduces some additional computational cost, particularly for larger datasets, the trade-off in terms of improved classification accuracy and leveraging semantic information from the training set during the inference phase is well justified. Future work will focus on further optimizing the algorithm to reduce the overhead without compromising the performance benefits.

5. Conclusions and Future Work

In this paper, we proposed a novel method called KRA, which integrates a classic k-nearest neighbors (KNN) classifier with a text augmentation module to enhance the performance of pre-trained language models on text classification tasks. The core ideas behind our approach are to leverage the semantic information of the training dataset during the inference phase and to improve the quality of the KNN retrieval set through text augmentation.
Traditional deep learning models, including pre-trained language models like BERT and RoBERTa, do not utilize the training dataset during the testing phase. This can be likened to a closed-book exam where the model cannot refer back to the training data. By incorporating KNN, our method effectively transforms the testing phase into an open-book exam, allowing the model to reference similar examples from the training dataset. This helps with making more informed and accurate classification decisions. The effectiveness of the KNN algorithm heavily depends on the quality and diversity of the retrieval set. Given that the training dataset might not always be comprehensive, we employ text augmentation techniques to enhance the dataset. By expanding the retrieval set with augmented samples, we ensure that the KNN algorithm has a richer and more diverse set of neighbors to draw from, thereby improving its performance.
Our extensive experiments demonstrate the efficacy of the proposed KRA method. The empirical results show that KRA consistently improves the performance of various models, including BERT, RoBERTa, CNN, and LSTM, across multiple datasets. Notably, our method achieves significant improvements on datasets with shorter average text lengths, such as MR, MPQA, AG’s News, SST-2, and DBPedia. Although the benefits are less pronounced on the IMDB dataset, which is characterized by longer texts, KRA still manages to enhance the performance of pre-trained models like BERT and RoBERTa. Furthermore, the ablation study highlights the importance of the text augmentation module in our approach, showing that the combination of traditional and LLM-based augmentation methods consistently yields better results than using either method alone. The hyperparameter analysis indicates that the best performance is achieved with a suitable balance between the number of neighbors (K) and the weight ratio ( λ ) of the KNN classifier’s results.
In future work, we will explore text augmentation methods that are better suited to the KNN algorithm to further improve the model's performance. Regarding neighbor search in the KNN algorithm, searching for neighbors along different dimensions or conducting multiple neighbor searches may improve the precision of the retrieved neighbors, reducing the probability of retrieving invalid neighbors. From the perspective of domain transfer, KRA could also be applied to multi-label classification or few-shot classification tasks. These are directions we intend to explore in depth. Further optimization of the hyperparameters for each specific situation or task is another goal we aim to study.

Author Contributions

Conceptualization, J.L.; methodology, J.L., C.T. and Z.L.; software, C.T.; validation, Z.L. and Y.Z.; formal analysis, J.L. and C.T.; investigation, J.L.; resources, J.L. and C.T.; data curation, Y.Z.; writing—original draft preparation, J.L.; writing—review and editing, Z.L. and L.H.; visualization, X.L. and Y.Y.; supervision, Y.Y.; project administration, R.P.; funding acquisition, Y.Y. All authors have read and agreed to the published version of the manuscript.

Funding

This research was funded by the National Natural Science Foundation of China (U22B2019).

Data Availability Statement

The raw data supporting the conclusions of this article will be made available by the authors on request.

Conflicts of Interest

The authors declare no conflicts of interest.

References

  1. Li, Q.; Peng, H.; Li, J.; Xia, C.; Yang, R.; Sun, L.; Yu, P.S.; He, L. A Survey on Text Classification: From Traditional to Deep Learning. ACM Trans. Intell. Syst. Technol. 2022, 13, 1–41. [Google Scholar] [CrossRef]
  2. Rao, S.; Verma, A.K.; Bhatia, T. A review on social spam detection: Challenges, open issues, and future directions. Expert Syst. Appl. 2021, 186, 115742. [Google Scholar] [CrossRef]
  3. Nandwani, P.; Verma, R. A review on sentiment analysis and emotion detection from text. Soc. Netw. Anal. Min. 2021, 11, 81. [Google Scholar] [CrossRef]
  4. Ji, S.; Pan, S.; Cambria, E.; Marttinen, P.; Philip, S.Y. A survey on knowledge graphs: Representation, acquisition, and applications. IEEE Trans. Neural Netw. Learn. Syst. 2021, 33, 494–514. [Google Scholar] [CrossRef]
  5. Liu, P.; Qiu, X.; Huang, X. Recurrent Neural Network for Text Classification with Multi-Task Learning. In Proceedings of the Twenty-Fifth International Joint Conference on Artificial Intelligence, New York, NY, USA, 9–15 July 2016; AAAI Press: Washington, DC, USA, 2016. IJCAI’16. pp. 2873–2879. [Google Scholar]
  6. Kim, Y. Convolutional Neural Networks for Sentence Classification. In Proceedings of the 2014 Conference on Empirical Methods in Natural Language Processing (EMNLP), Doha, Qatar, 25–29 October 2014; Association for Computational Linguistics: Doha, Qatar, 2014; pp. 1746–1751. [Google Scholar] [CrossRef]
  7. Vaswani, A.; Shazeer, N.; Parmar, N.; Uszkoreit, J.; Jones, L.; Gomez, A.N.; Kaiser, L.; Polosukhin, I. Attention is All You Need. In Proceedings of the 31st International Conference on Neural Information Processing Systems, Long Beach, CA, USA, 4–9 December 2017; Curran Associates Inc.: Red Hook, NY, USA, 2017. NIPS’17. pp. 6000–6010. [Google Scholar]
  8. Devlin, J.; Chang, M.W.; Lee, K.; Toutanova, K. BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding. In Proceedings of the 2019 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long and Short Papers), Minneapolis, MN, USA, 2–7 June 2019; Association for Computational Linguistics: Minneapolis, MN, USA, 2019; pp. 4171–4186, N19-1423. [Google Scholar] [CrossRef]
  9. Liu, Y.; Ott, M.; Goyal, N.; Du, J.; Joshi, M.; Chen, D.; Levy, O.; Lewis, M.; Zettlemoyer, L.; Stoyanov, V. RoBERTa: A robustly optimized BERT pretraining approach. arXiv 2019, arXiv:1907.11692. [Google Scholar]
  10. Yang, Z.; Dai, Z.; Yang, Y.; Carbonell, J.; Salakhutdinov, R.; Le, Q.V. XLNet: Generalized Autoregressive Pretraining for Language Understanding. In Proceedings of the 33rd International Conference on Neural Information Processing Systems, Vancouver, BC, Canada, 8–14 December 2019; Curran Associates Inc.: Red Hook, NY, USA, 2019. [Google Scholar]
  11. Brown, T.B.; Mann, B.; Ryder, N.; Subbiah, M.; Kaplan, J.; Dhariwal, P.; Neelakantan, A.; Shyam, P.; Sastry, G.; Askell, A.; et al. Language Models Are Few-Shot Learners. In Proceedings of the 34th International Conference on Neural Information Processing Systems, Online, 6–12 December 2020; Curran Associates Inc.: Red Hook, NY, USA, 2020. NIPS’20. [Google Scholar]
  12. Li, H.; Su, Y.; Cai, D.; Wang, Y.; Liu, L. A survey on retrieval-augmented text generation. arXiv 2022, arXiv:2202.01110. [Google Scholar]
  13. Li, J.; Galley, M.; Brockett, C.; Gao, J.; Dolan, B. A Diversity-Promoting Objective Function for Neural Conversation Models. In Proceedings of the 2016 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, San Diego, CA, USA, 12–17 June 2016; Association for Computational Linguistics: San Diego, CA, USA, 2016; pp. 110–119. [Google Scholar] [CrossRef]
  14. Zheng, X.; Zhang, Z.; Guo, J.; Huang, S.; Chen, B.; Luo, W.; Chen, J. Adaptive nearest neighbor machine translation. arXiv 2021, arXiv:2105.13022. [Google Scholar]
  15. Khandelwal, U.; Levy, O.; Jurafsky, D.; Zettlemoyer, L.; Lewis, M. Generalization through memorization: Nearest neighbor language models. arXiv 2019, arXiv:1911.00172. [Google Scholar]
  16. Wang, S.; Li, X.; Meng, Y.; Zhang, T.; Ouyang, R.; Li, J.; Wang, G. kNN-NER: Named Entity Recognition with Nearest Neighbor Search. arXiv 2022, arXiv:2203.17103. [Google Scholar]
  17. Borgeaud, S.; Mensch, A.; Hoffmann, J.; Cai, T.; Rutherford, E.; Millican, K.; Van Den Driessche, G.B.; Lespiau, J.B.; Damoc, B.; Clark, A.; et al. Improving language models by retrieving from trillions of tokens. In Proceedings of the International Conference on Machine Learning PMLR, Baltimore, MD, USA, 17–23 July 2022; pp. 2206–2240. [Google Scholar]
  18. Liu, G.; Guo, J. Bidirectional LSTM with attention mechanism and convolutional layer for text classification. Neurocomputing 2019, 337, 325–338. [Google Scholar] [CrossRef]
  19. Shen, D.; Zhang, Y.; Henao, R.; Su, Q.; Carin, L. Deconvolutional latent-variable model for text sequence matching. In Proceedings of the AAAI Conference on Artificial Intelligence, New Orleans, LA, USA, 2–7 February 2018; Volume 32. [Google Scholar]
  20. Kassner, N.; Schütze, H. BERT-kNN: Adding a kNN search component to pretrained language models for better QA. arXiv 2020, arXiv:2005.00766. [Google Scholar]
  21. Li, L.; Song, D.; Ma, R.; Qiu, X.; Huang, X. KNN-BERT: Fine-tuning pre-trained models with KNN classifier. arXiv 2021, arXiv:2110.02523. [Google Scholar]
  22. Yu, A.W.; Dohan, D.; Luong, M.T.; Zhao, R.; Chen, K.; Norouzi, M.; Le, Q.V. Qanet: Combining local convolution with global self-attention for reading comprehension. arXiv 2018, arXiv:1804.09541. [Google Scholar]
  23. Xie, Z.; Wang, S.I.; Li, J.; Lévy, D.; Nie, A.; Jurafsky, D.; Ng, A.Y. Data noising as smoothing in neural network language models. arXiv 2017, arXiv:1703.02573. [Google Scholar]
  24. Kobayashi, S. Contextual augmentation: Data augmentation by words with paradigmatic relations. arXiv 2018, arXiv:1805.06201. [Google Scholar]
  25. Wei, J.; Zou, K. EDA: Easy data augmentation techniques for boosting performance on text classification tasks. arXiv 2019, arXiv:1901.11196. [Google Scholar]
  26. Dai, H.; Liu, Z.; Liao, W.; Huang, X.; Cao, Y.; Wu, Z.; Zhao, L.; Xu, S.; Liu, W.; Liu, N.; et al. AugGPT: Leveraging ChatGPT for text data augmentation. arXiv 2023, arXiv:2302.13007. [Google Scholar]
  27. Hu, L.; He, H.; Wang, D.; Zhao, Z.; Shao, Y.; Nie, L. LLM vs Small Model? Large Language Model Based Text Augmentation Enhanced Personality Detection Model. Proc. AAAI Conf. Artif. Intell. 2024, 38, 18234–18242. [Google Scholar] [CrossRef]
  28. Pang, B.; Lee, L.; Vaithyanathan, S. Thumbs up? Sentiment classification using machine learning techniques. arXiv 2002, arXiv:cs/0205070. [Google Scholar]
  29. Stoyanov, V.; Cardie, C.; Wiebe, J. Multi-Perspective Question Answering Using the OpQA Corpus. In Proceedings of the Human Language Technology Conference and Conference on Empirical Methods in Natural Language Processing, Vancouver, BC, Canada, 6–8 October 2005; Mooney, R., Brew, C., Chien, L.F., Kirchhoff, K., Eds.; Association for Computational Linguistics: Vancouver, BC, Canada, 2005; pp. 923–930. [Google Scholar]
  30. Maas, A.L.; Daly, R.E.; Pham, P.T.; Huang, D.; Ng, A.Y.; Potts, C. Learning Word Vectors for Sentiment Analysis. In Proceedings of the 49th Annual Meeting of the Association for Computational Linguistics: Human Language Technologies, Portland, OR, USA, 19–24 June 2011; Association for Computational Linguistics: Portland, OR, USA, 2011; pp. 142–150. [Google Scholar]
  31. Zhang, X.; Zhao, J.; LeCun, Y. Character-Level Convolutional Networks for Text Classification. In Proceedings of the 28th International Conference on Neural Information Processing Systems—Volume 1, Montreal, QC, Canada, 7–12 December 2015; MIT Press: Cambridge, MA, USA, 2015. NIPS’15. pp. 649–657. [Google Scholar]
  32. Socher, R.; Perelygin, A.; Wu, J.; Chuang, J.; Manning, C.D.; Ng, A.; Potts, C. Recursive Deep Models for Semantic Compositionality Over a Sentiment Treebank. In Proceedings of the 2013 Conference on Empirical Methods in Natural Language Processing, Seattle, WA, USA, 18–21 October 2013; Association for Computational Linguistics: Seattle, WA, USA, 2013; pp. 1631–1642. [Google Scholar]
  33. Kingma, D.P.; Ba, J. Adam: A method for stochastic optimization. arXiv 2014, arXiv:1412.6980. [Google Scholar]
Figure 1. The structure of our proposed method. The embeddings of the text in the training set are saved in the datastore as key–value pairs along with their labels. During the inference phase, the text is transformed into embeddings by the model, and then, the k-nearest neighbors are searched for in the datastore. Text augmentation is applied to these neighbors, and the KNN prediction is calculated based on their labels. The final prediction is determined by a weighted calculation combining the KNN prediction with the model prediction.
Figure 2. Hyperparameter analysis of KRA, including the number of neighbors K and the ratio λ .
Figure 3. Test accuracy on AG’s News over different fine-tuning steps.
Table 1. Sentences generated using different text augmentation methods. LLM: LLM generation. SR: synonym replacement. RI: random insertion. RS: random swap. RD: random deletion. BT: back translation.
Method | Sentence
Original | Luka Doncic is a famous Slovenian basketball player.
LLM | 1. Renowned for his skills, Luka Doncic is a basketball player from Slovenia.
| 2. Slovenian basketball star Luka Doncic is widely celebrated in the sport.
| 3. Hailing from Slovenia, Luka Doncic has gained fame as a basketball player.
| 4. Luka Doncic, a celebrated figure in basketball, proudly represents Slovenia.
| 5. The fame of Slovenian basketball player Luka Doncic extends far and wide.
SR | Luka Doncic is a well-known Slovenian basketball player.
RI | Luka Doncic is a famous Slovenian basketball happiness player.
RS | Luka Doncic is a basketball Slovenian famous player.
RD | Luka Doncic is a famous Slovenian player.
BT | Luka Doncic is a famous basketball player from Slovenia.
Table 2. Prompt using GPT-3.5 to generate augmented texts.
Type | Prompt
System | You are a helpful assistant that rewrites texts and makes sentences smooth, meaningful, and concise.
User | I will give you a sample, please rewrite it, then give me 5 rewrite answers.
Assistant | Sure, please provide the sample text that you’d like me to rewrite.
User | {text}
Table 3. Test accuracy on six text classification datasets. We run all models and report their mean ± standard deviations. The bold means a significant improvement over baseline methods based on the t-test (p < 0.05).
Model | MR | MPQA | IMDB | AG's News | SST-2 | DBPedia
BERT | 0.8874 ± 0.0034 | 0.8793 ± 0.0062 | 0.9523 ± 0.0029 | 0.9461 ± 0.0018 | 0.9115 ± 0.0073 | 0.9743 ± 0.0034
BERT + KRA | 0.8883 ± 0.0079 | 0.8845 ± 0.0026 | 0.9532 ± 0.0046 | 0.9512 ± 0.0121 | 0.9124 ± 0.0093 | 0.9764 ± 0.0070
RoBERTa | 0.9070 ± 0.0033 | 0.8835 ± 0.0042 | 0.9568 ± 0.0052 | 0.9480 ± 0.0235 | 0.9383 ± 0.0086 | 0.9786 ± 0.0063
RoBERTa + KRA | 0.9099 ± 0.0087 | 0.8875 ± 0.0085 | 0.9603 ± 0.0062 | 0.9545 ± 0.0137 | 0.9423 ± 0.0052 | 0.9799 ± 0.0064
CNN | 0.8078 ± 0.0046 | 0.8647 ± 0.0065 | 0.8657 ± 0.0036 | 0.8969 ± 0.0019 | 0.8656 ± 0.0009 | 0.9566 ± 0.0055
CNN + KRA | 0.8103 ± 0.0042 | 0.8675 ± 0.0092 | 0.8649 ± 0.0067 | 0.8974 ± 0.0060 | 0.8671 ± 0.0053 | 0.9572 ± 0.0086
LSTM | 0.8196 ± 0.0065 | 0.8665 ± 0.0054 | 0.8832 ± 0.0075 | 0.8793 ± 0.0026 | 0.8545 ± 0.0012 | 0.9342 ± 0.0037
LSTM + KRA | 0.8221 ± 0.0053 | 0.8688 ± 0.0135 | 0.8828 ± 0.0034 | 0.8801 ± 0.0028 | 0.8570 ± 0.0154 | 0.9357 ± 0.0110
Table 4. Ablation study: comparison of accuracy with and without text augmentation module on six datasets using BERT and RoBERTa. The bold indicates the best performance.
Model | MR | MPQA | IMDB | AG's News | SST-2 | DBPedia
BERT | 0.8874 | 0.8793 | 0.9523 | 0.9461 | 0.9115 | 0.9743
BERT + KNN | 0.8876 | 0.8822 | 0.9528 | 0.9476 | 0.9117 | 0.9735
BERT + KRA | 0.8883 | 0.8845 | 0.9532 | 0.9512 | 0.9124 | 0.9764
RoBERTa | 0.9070 | 0.8835 | 0.9568 | 0.9480 | 0.9383 | 0.9786
RoBERTa + KNN | 0.9079 | 0.8859 | 0.9584 | 0.9510 | 0.9406 | 0.9779
RoBERTa + KRA | 0.9099 | 0.8875 | 0.9603 | 0.9545 | 0.9423 | 0.9799
Table 5. Ablation study: comparison of accuracy between traditional methods and LLM-based method using BERT. The bold indicates the best performance.
Method | MR | MPQA | IMDB | AG's News | SST-2 | DBPedia
Traditional | 0.8863 | 0.8826 | 0.9535 | 0.9490 | 0.9110 | 0.9759
LLM-based | 0.8874 | 0.8831 | 0.9526 | 0.9510 | 0.9118 | 0.9750
TRA + LLM | 0.8883 | 0.8845 | 0.9532 | 0.9512 | 0.9124 | 0.9764
Table 6. Test accuracy on two datasets using frozen BERT as the baseline.
ModelAG’s NewsDBPedia
frozen BERT0.91300.9711
frozen BERT + KRA0.91650.9723
Table 7. Inference time with and without KRA.
Model | MR | MPQA | IMDB | AG's News
BERT | 19.2 s | 18.9 s | 452.3 s | 137.9 s
BERT + KRA | 20.1 s | 19.8 s | 491.8 s | 205.9 s


