#!/usr/bin/python
# -*- coding:utf-8 -*-
# File Name: text_encode.py
# Author: Keming Lu
# Mail: keminglu@usc.edu
# Created Time: 2022-03-22 02:17:37
import torch
import numpy as np
from transformers import AutoModel
from transformers import AutoTokenizer
from tqdm import tqdm

# Mapping from short model names to Hugging Face model identifiers.
config_dict = {
    "coder": "GanjinZero/coder_eng",
    "coder_all": "GanjinZero/coder_all",
    "coder_old": "GanjinZero/UMLSBert_ENG",
    "coder_pp": "GanjinZero/coder_eng_pp",
    "sapbert": "cambridgeltl/SapBERT-from-PubMedBERT-fulltext",
    "pubmedbert": "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext",
    "biobert": "monologg/biobert_v1.1_pubmed",
    "bert": "bert-base-uncased"
}


class TextEncoder(object):
    def __init__(
        self,
        model_name,
        device,
        run_check_max_length=False
    ):
        self.device = device
        self.model_name = model_name
        self.config = config_dict[self.model_name]
        self.tokenizer = AutoTokenizer.from_pretrained(self.config)
        self.model = AutoModel.from_pretrained(self.config).to(self.device)
        self.run_check_max_length = run_check_max_length
        self.pred_batch_size = 64
        self.max_length = 32  # 128

    def check_max_length(self, inputs):
        # Count how many inputs would be truncated at the current max_length.
        cnt = 0
        for each in inputs:
            ids = self.tokenizer.encode_plus(
                each, add_special_tokens=True, padding='do_not_pad')['input_ids']
            if len(ids) > self.max_length:
                cnt += 1
        all_cnt = len(inputs)
        print(f"Current max length is {self.max_length}.")
        print(f"{(1 - cnt / all_cnt) * 100:.2f}% of samples fit within length {self.max_length}.")

    def get_embed(
        self,
        phrase_list,
        normalize=True,
        summary_method="CLS"
    ):
        if self.run_check_max_length:
            self.check_max_length(phrase_list)
        # Tokenize every phrase, padding and truncating to max_length.
        input_ids = []
        pbar = tqdm(total=len(phrase_list))
        pbar.set_description("Tokenizing phrases:")
        for phrase in phrase_list:
            input_ids.append(self.tokenizer.encode_plus(
                phrase, max_length=self.max_length, add_special_tokens=True,
                truncation=True, padding='max_length')['input_ids'])
            pbar.update(1)
        self.model.eval()
        count = len(input_ids)
        now_count = 0
        pbar = tqdm(total=count)
        pbar.set_description("Encoding:")
        with torch.no_grad():
            while now_count < count:
                # Encode one batch of at most pred_batch_size phrases.
                input_gpu_0 = torch.LongTensor(input_ids[now_count:min(
                    now_count + self.pred_batch_size, count)]).to(self.device)
                if summary_method == "CLS":
                    embed = self.model(input_gpu_0)[1]
                if summary_method == "MEAN":
                    embed = torch.mean(self.model(input_gpu_0)[0], dim=1)
                if normalize:
                    embed_norm = torch.norm(
                        embed, p=2, dim=1, keepdim=True).clamp(min=1e-12)
                    embed = embed / embed_norm
                embed_np = embed.cpu().detach().numpy()
                if now_count == 0:
                    output = embed_np
                else:
                    output = np.concatenate((output, embed_np), axis=0)
                now_count = min(now_count + self.pred_batch_size, count)
                pbar.update(self.pred_batch_size)
        return output


if __name__ == "__main__":
    phrases = ["abs"]
    device = torch.device("cuda:0")
    encoder = TextEncoder("pubmedbert", device)
    print(encoder.get_embed(phrases).shape)

text_encode
Overview
The text_encode.py script provides a framework for encoding text phrases into vector embeddings using pre-trained models from Hugging Face’s transformers library. It is designed for biomedical and general-purpose text encoding, with support for several pre-trained models such as PubMedBERT, SapBERT, and BERT, and is useful for natural language processing (NLP) tasks such as similarity computation and classification.
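A minimal usage sketch, assuming text_encode.py is on the Python path, torch and transformers are installed, and a GPU is available as in the script's __main__ block; the phrases are illustrative:

```python
import torch
from text_encode import TextEncoder

# Mirrors the __main__ block above: load PubMedBERT on the first GPU.
device = torch.device("cuda:0")
encoder = TextEncoder("pubmedbert", device)

# Encode two illustrative phrases with the defaults (CLS pooling, L2 normalization).
phrases = ["aspirin", "acetylsalicylic acid"]
embeddings = encoder.get_embed(phrases)
print(embeddings.shape)  # (2, 768) for this base-size model
```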
Key Components
config_dict
A dictionary mapping model names to their respective Hugging Face identifiers. This allows for flexibility in switching between different models.
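For example, the available short names can be listed at runtime, and a new checkpoint can be registered by adding an entry; the path below is a hypothetical placeholder:

```python
from text_encode import config_dict

# Short names accepted as model_name by TextEncoder.
print(sorted(config_dict.keys()))

# Hypothetical: register an additional checkpoint under a new short name.
config_dict["my_local_model"] = "/path/to/local/checkpoint"  # placeholder path
```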
TextEncoder Class
A class for encoding text into vector embeddings. It includes the following features:
Initialization:
- Loads the specified pre-trained model and tokenizer.
- Configures device settings (e.g., CPU or GPU); see the sketch after this list.
- Optionally checks how many inputs exceed the maximum tokenized length.
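A construction sketch under the assumption that a GPU may or may not be present (model names come from config_dict above):

```python
import torch
from text_encode import TextEncoder

# Fall back to CPU when CUDA is not available.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# With run_check_max_length=True, get_embed() first reports what fraction of
# the inputs fit within the current max_length (32 by default).
encoder = TextEncoder("coder_pp", device, run_check_max_length=True)
```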
Methods:
- check_max_length(inputs): Reports what fraction of the input phrases fit within the model's maximum tokenized length.
- get_embed(phrase_list, normalize=True, summary_method="CLS"): Encodes a list of phrases into embeddings using the specified summary method (CLS pooling or mean pooling over tokens) and optional L2 normalization; both methods are shown in the sketch below.
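A sketch of both methods in use (phrases are illustrative; argument names come from the code above):

```python
import torch
from text_encode import TextEncoder

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
encoder = TextEncoder("coder", device)

phrases = ["type 2 diabetes mellitus", "hypertension", "chronic kidney disease"]

# Report what fraction of the phrases fit within the current max_length.
encoder.check_max_length(phrases)

# Mean-pool the token embeddings instead of taking the CLS pooled output,
# and skip L2 normalization.
raw = encoder.get_embed(phrases, normalize=False, summary_method="MEAN")
print(raw.shape)  # (3, hidden_size)
```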
Batch Processing:
- Processes inputs in chunks of pred_batch_size to optimize memory usage and performance (see the sketch below).
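Since pred_batch_size and max_length are plain instance attributes set in __init__, they can be overridden after construction; a sketch:

```python
import torch
from text_encode import TextEncoder

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
encoder = TextEncoder("bert", device)

# Smaller batches lower peak GPU memory; a larger max_length keeps longer
# phrases from being truncated, at the cost of more compute per batch.
encoder.pred_batch_size = 16
encoder.max_length = 64

embeddings = encoder.get_embed(["a fairly long phrase that might otherwise be truncated"])
```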
Output:
- Returns embeddings as a single NumPy array of shape (len(phrase_list), hidden_size); with normalize=True, each row is L2-normalized (see the similarity sketch below).
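For instance, with the default normalize=True the rows are unit vectors, so cosine similarity reduces to a dot product (phrases are illustrative):

```python
import numpy as np
import torch
from text_encode import TextEncoder

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
encoder = TextEncoder("sapbert", device)

phrases = ["myocardial infarction", "heart attack", "bone fracture"]
embeddings = encoder.get_embed(phrases)  # normalize=True by default -> unit-length rows

# With unit-length rows, the pairwise dot products are cosine similarities.
similarity = embeddings @ embeddings.T
print(np.round(similarity, 3))
```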