```python
#!/usr/bin/python
# -*- coding:utf-8 -*-
# File Name: text_encode.py
# Author: Keming Lu
# Mail: keminglu@usc.edu
# Created Time: 2022-03-22 02:17:37
import torch
import numpy as np
from transformers import AutoModel
from transformers import AutoTokenizer
from tqdm import tqdm

# Map short model names to their Hugging Face Hub identifiers.
config_dict = {
    "coder": "GanjinZero/coder_eng",
    "coder_all": "GanjinZero/coder_all",
    "coder_old": "GanjinZero/UMLSBert_ENG",
    "coder_pp": "GanjinZero/coder_eng_pp",
    "sapbert": "cambridgeltl/SapBERT-from-PubMedBERT-fulltext",
    "pubmedbert": "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext",
    "biobert": "monologg/biobert_v1.1_pubmed",
    "bert": "bert-base-uncased"
}


class TextEncoder(object):
    def __init__(
        self,
        model_name,
        device,
        run_check_max_length=False
    ):
        self.device = device
        self.model_name = model_name
        self.config = config_dict[self.model_name]
        self.tokenizer = AutoTokenizer.from_pretrained(self.config)
        self.model = AutoModel.from_pretrained(self.config).to(self.device)
        self.run_check_max_length = run_check_max_length
        self.pred_batch_size = 64
        self.max_length = 32  # 128

    def check_max_length(self, inputs):
        # Count how many inputs exceed the configured maximum token length.
        cnt = 0
        for each in inputs:
            ids = self.tokenizer.encode_plus(
                each, add_special_tokens=True, padding='do_not_pad')['input_ids']
            if len(ids) > self.max_length:
                cnt += 1
        all_cnt = len(inputs)
        print(f"Current max length is {self.max_length}.")
        print(f"{(1 - cnt / all_cnt) * 100}% samples can fit in length {self.max_length}.")

    def get_embed(
        self,
        phrase_list,
        normalize=True,
        summary_method="CLS"
    ):
        if self.run_check_max_length:
            self.check_max_length(phrase_list)
        # Tokenize every phrase, padded and truncated to max_length.
        input_ids = []
        pbar = tqdm(total=len(phrase_list))
        pbar.set_description("Tokenizing phrases:")
        for phrase in phrase_list:
            input_ids.append(self.tokenizer.encode_plus(
                phrase, max_length=self.max_length, add_special_tokens=True,
                truncation=True, padding='max_length')['input_ids'])
            pbar.update(1)
        self.model.eval()
        count = len(input_ids)
        now_count = 0
        pbar = tqdm(total=count)
        pbar.set_description("Encoding:")
        with torch.no_grad():
            while now_count < count:
                input_gpu_0 = torch.LongTensor(input_ids[now_count:min(
                    now_count + self.pred_batch_size, count)]).to(self.device)
                if summary_method == "CLS":
                    # Pooler output for the [CLS] token.
                    embed = self.model(input_gpu_0)[1]
                if summary_method == "MEAN":
                    # Mean of the last hidden states over all token positions.
                    embed = torch.mean(self.model(input_gpu_0)[0], dim=1)
                if normalize:
                    # L2-normalize each embedding row.
                    embed_norm = torch.norm(
                        embed, p=2, dim=1, keepdim=True).clamp(min=1e-12)
                    embed = embed / embed_norm
                embed_np = embed.cpu().detach().numpy()
                if now_count == 0:
                    output = embed_np
                else:
                    output = np.concatenate((output, embed_np), axis=0)
                now_count = min(now_count + self.pred_batch_size, count)
                pbar.update(self.pred_batch_size)
        return output


if __name__ == "__main__":
    phrases = ["abs"]
    device = torch.device("cuda:0")
    encoder = TextEncoder("pubmedbert", device)
    print(encoder.get_embed(phrases).shape)
```
# text_encode

## Overview

The `text_encode.py` script provides a framework for encoding text phrases into vector embeddings using pre-trained models from Hugging Face's `transformers` library. It is designed for biomedical and general-purpose text encoding, with support for several pre-trained models such as PubMedBERT, SapBERT, and BERT. The script is especially useful for natural language processing (NLP) tasks such as similarity computation and classification.
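A minimal quick start, mirroring the `__main__` block at the bottom of the script. It assumes the file is importable as `text_encode` and that a `cuda:0` device exists; the example phrases are illustrative:

```python
import torch
from text_encode import TextEncoder

device = torch.device("cuda:0")  # or torch.device("cpu")
encoder = TextEncoder("pubmedbert", device)
embeddings = encoder.get_embed(["myocardial infarction", "heart attack"])
print(embeddings.shape)  # (2, 768) for a BERT-base backbone
```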
## Key Components

### `config_dict`

A dictionary mapping short model names to their respective Hugging Face identifiers, which allows switching between models without other code changes; see the sketch below.
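For example, choosing a different backbone only requires a different key (a sketch; the device choice is an assumption):

```python
import torch
from text_encode import TextEncoder

device = torch.device("cpu")  # illustrative; use "cuda:0" if a GPU is available
sap_encoder = TextEncoder("sapbert", device)  # cambridgeltl/SapBERT-from-PubMedBERT-fulltext
bio_encoder = TextEncoder("biobert", device)  # monologg/biobert_v1.1_pubmed
```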
### `TextEncoder` Class

A class for encoding text into vector embeddings. It includes the following features:
**Initialization:**

- Loads the specified pre-trained model and tokenizer.
- Configures device settings (e.g., CPU or GPU).
- Allows optional checking of maximum input length for tokenization (see the sketch after this list).
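A minimal construction sketch, assuming the script is importable as `text_encode`; the model key and device are illustrative:

```python
import torch
from text_encode import TextEncoder

# run_check_max_length=True makes get_embed first report what fraction of
# phrases fit within self.max_length (32 tokens by default).
encoder = TextEncoder("coder", torch.device("cuda:0"), run_check_max_length=True)
```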
**Methods:**

- `check_max_length(inputs)`: Reports what fraction of input phrases fit within the model's maximum allowable token length.
- `get_embed(phrase_list, normalize=True, summary_method="CLS")`: Encodes a list of phrases into embeddings using the specified summary method and optional L2 normalization (both pooling modes are shown in the sketch below).
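A sketch of the two pooling modes, reusing the `encoder` from the sketch above (the phrase strings are only illustrative):

```python
phrases = ["atrial fibrillation", "irregular heartbeat"]

# "CLS" (default): uses the model's pooler output for the [CLS] token.
cls_embed = encoder.get_embed(phrases, summary_method="CLS")

# "MEAN": averages the last hidden states over all token positions,
# here without L2 normalization.
mean_embed = encoder.get_embed(phrases, summary_method="MEAN", normalize=False)
```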
**Batch Processing:**

- Processes inputs in fixed-size batches (64 by default) to bound memory usage; the batch size can be adjusted as sketched below.
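Since `pred_batch_size` is a plain instance attribute set in `__init__`, it can be lowered for a small GPU (a sketch; the value 16 is arbitrary):

```python
encoder.pred_batch_size = 16  # default is 64; smaller batches use less GPU memory
```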
**Output:**

- Returns embeddings as a NumPy array of shape `(len(phrase_list), hidden_size)`; rows are unit-normalized when `normalize=True`.
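Because rows are unit-normalized by default, pairwise cosine similarity reduces to a matrix product. A usage sketch (the phrases are illustrative):

```python
import numpy as np

embeds = encoder.get_embed(["kidney failure", "renal failure", "headache"])
sims = embeds @ embeds.T  # pairwise cosine similarities for unit-norm rows
print(np.round(sims, 3))
```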