Skip to content

Commit

Permalink
update entity extractor to train single model for all entities
Browse files Browse the repository at this point in the history
  • Loading branch information
alfredfrancis committed Jan 10, 2025
1 parent 6dfdb39 commit 1f29794
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 25 deletions.
30 changes: 22 additions & 8 deletions app/bot/nlu/entity_extractors/crf_entity_extractor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import pycrfsuite
from app.config import app_config

MODEL_NAME = "crf__entity_extractor.model"

class CRFEntityExtractor:
"""
Performs NER training, prediction, model import/export
Expand All @@ -10,6 +12,7 @@ def __init__(self, synonyms={}):
import spacy
self.tokenizer = spacy.load("en_core_web_md")
self.synonyms = synonyms
self.tagger = None

def replace_synonyms(self, entities):
"""
Expand Down Expand Up @@ -88,7 +91,7 @@ def sent_to_labels(self, sent):
"""
return [label for token, postag, label in sent]

def train(self, train_sentences, model_name):
def train(self, train_sentences, model_path: str):
"""
Train NER model for given model
:param train_sentences:
Expand All @@ -110,9 +113,23 @@ def train(self, train_sentences, model_name):
# include transitions that are possible, but not observed
'feature.possible_transitions': True
})
trainer.train('model_files/%s.model' % model_name)
trainer.train(f"{model_path}/{MODEL_NAME}")
return True

def load(self, model_path: str) -> bool:
    """
    Load the CRF model from the given path
    :param model_path: Path to the model directory
    :return: True if successful, False otherwise
    """
    # Build the path from MODEL_NAME, the same constant train() uses when
    # writing the model; a separately hard-coded file name would never
    # match the file that training produces.
    model_file = f"{model_path}/{MODEL_NAME}"
    tagger = pycrfsuite.Tagger()
    try:
        tagger.open(model_file)
    except Exception as e:
        # Best-effort load: report and signal failure to the caller.
        # self.tagger is left untouched so a half-initialised tagger is
        # never exposed to predict().
        print(f"Error loading CRF model: {e}")
        return False
    # Only publish the tagger once the model file opened successfully.
    self.tagger = tagger
    return True

def crf2json(self, tagged_sentence):
"""
Extract label-value pair from NER prediction output
Expand Down Expand Up @@ -143,19 +160,16 @@ def extract_ner_labels(self, predicted_labels):
labels.append(tp[2:])
return labels

def predict(self, message):
    """
    Predict NER labels for given message
    :param message: dict carrying the parsed document under "spacy_doc"
    :return: extracted entities after synonym replacement
    :raises RuntimeError: if no model has been loaded yet
    """
    # The tagger is created by load(); fail with a clear message instead of
    # an AttributeError on None when prediction is attempted too early.
    if self.tagger is None:
        raise RuntimeError(
            "CRF entity model is not loaded; call load() before predict()")

    spacy_doc = message.get("spacy_doc")
    tagged_token = self.pos_tagger(spacy_doc)
    words = [token.text for token in spacy_doc]
    # Reuse the tagger opened once in load() rather than re-opening the
    # model file on every prediction.
    predicted_labels = self.tagger.tag(self.sent_to_features(tagged_token))
    extracted_entities = self.crf2json(
        zip(words, predicted_labels))
    return self.replace_synonyms(extracted_entities)
Expand Down
25 changes: 8 additions & 17 deletions app/bot/nlu/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,28 +111,19 @@ def __init__(self, synonyms: Optional[Dict[str, str]] = None):
self.extractor = CRFEntityExtractor(synonyms or {})

def train(self, training_data: List[Dict[str, Any]], model_path: str) -> None:
    """Train a single CRF entity model from all training examples.

    :param training_data: labelled examples, converted by the extractor's
        ``json2crf`` before training
    :param model_path: directory handed to the extractor's ``train``
    """
    # Convert all training data to CRF format at once
    ner_training_data = self.extractor.json2crf(training_data)
    # Train a single model for all entities. Forward the caller's
    # model_path: passing a fixed literal here would ignore the argument
    # and write the model somewhere load(model_path) never looks.
    self.extractor.train(ner_training_data, model_path)

def load(self, model_path: str) -> bool:
    """Load the single entity model from ``model_path``.

    :return: whatever the extractor's ``load`` reports, so a failed load
        is visible to the caller instead of being reported as success
    """
    return self.extractor.load(model_path)

def process(self, message: Dict[str, Any]) -> Dict[str, Any]:
    """Attach extracted entities to ``message`` and return it.

    The message is returned unchanged when it has no text or no parsed
    ``spacy_doc``. Entity extraction does not depend on a predicted
    intent, because one model serves all entities.
    """
    if not message.get("text") or not message.get("spacy_doc"):
        return message

    message["entities"] = self.extractor.predict(message)
    return message

0 comments on commit 1f29794

Please sign in to comment.