from fastai.vision.all import *

# Register '.webp' as an image extension at import time.
# NOTE(review): presumably this is what lets `get_image_files` (used by
# MultiCaReClassifier) pick up WebP images -- confirm against fastai.
image_extensions.add('.webp')
|
|
class MultiCaReClassifier:
    '''Class used to classify medical images considering their types (such as ultrasound or MRI),
    and the corresponding anatomical region and view (for radiology images only).

    The models form a hierarchy encoded in their names: a model named
    "<upper_model>:<class>~<suffix>" or "<upper_model>.<class>" is only applied to the images
    for which the upper model predicted <class>. The outcome is stored in `self.data`, a
    dataframe with one row per image, an 'image_path' column, a 'label_list' column and,
    optionally, one multiclass column per label family.'''

    # Folder that was always used to look for the models before `models_root` was honored.
    # It is kept as a fallback so that setups relying on it keep working.
    _LEGACY_MODELS_ROOT = 'models'

    # Classes that each model can predict. The key order defines the order of the labels
    # in 'label_list'.
    _DEFAULT_LABEL_DICT = {
        "image_type:radiology~anatomical_region:axial_region": ["abdomen", "breast", "head", "neck", "pelvis", "thorax"],
        "image_type:radiology~anatomical_region:lower_limb": ["ankle", "foot", "hip", "knee", "lower_leg", "thigh"],
        "image_type:radiology~anatomical_view": ["axial", "frontal", "intravascular", "oblique", "occlusal", "panoramic", "periapical", "sagittal", "transabdominal", "transesophageal", "transthoracic", "transvaginal"],
        "image_type:endoscopy": ["airway_endoscopy", "arthroscopy", "ig_endoscopy", "other_endoscopy"],
        "image_type:electrography": ["eeg", "ekg", "emg"],
        "image_type:ophthalmic_imaging": ["autofluorescence", "b_scan", "fundus_photograph", "gonioscopy", "oct", "ophtalmic_angiography", "slit_lamp_photograph"],
        "image_type:radiology~anatomical_region:upper_limb": ["elbow", "forearm", "hand", "shoulder", "upper_arm", "wrist"],
        "image_type:radiology~anatomical_region": ["axial_region", "lower_limb", "upper_limb", "whole_body"],
        "image_type:radiology~main": ["ct", "mri", "pet", "scintigraphy", "spect", "tractography", "ultrasound", "x_ray"],
        "image_type:pathology": ["acid_fast", "alcian_blue", "congo_red", "fish", "giemsa", "gram", "h&e", "immunostaining", "masson_trichrome", "methenamine_silver", "methylene_blue", "papanicolaou", "pas", "van_gieson"],
        "image_type:radiology~anatomical_region:axial_region.thorax": ["cardiac_image", "other_thoracic_image"],
        "image_type:medical_photograph": ["oral_photograph", "other_medical_photograph", "skin_photograph"],
        "image_type": ["chart", "electrography", "endoscopy", "medical_photograph", "ophthalmic_imaging", "pathology", "radiology"],
    }

    # Raw model outputs that are renamed during postprocessing (spelling fixes and merging
    # of all the ultrasound views into a single auxiliary class).
    _LABEL_REPLACEMENTS = {
        'transesophageal': 'ultrasound_view',
        'transthoracic': 'ultrasound_view',
        'transabdominal': 'ultrasound_view',
        'transvaginal': 'ultrasound_view',
        'ophtalmic_angiography': 'ophthalmic_angiography',
        'ig_endoscopy': 'gi_endoscopy',
    }

    # Classes that are appended when all of their components are present in a label list.
    _COMPOUND_CLASSES = (
        {'compound_class': 'echocardiogram', 'components': ['ultrasound', 'cardiac_image']},
        {'compound_class': 'ivus', 'components': ['ultrasound', 'intravascular']},
        {'compound_class': 'mammography', 'components': ['x_ray', 'breast']},
    )

    # Intermediate classes that are removed from the final label list.
    _AUXILIARY_CLASSES = ('axial_region', 'cardiac_image', 'other_thoracic_image', 'intravascular', 'ultrasound_view')

    # (column name, candidate labels) for each multiclass column, in creation order.
    _MULTICLASS_COLUMNS = (
        ('image_type', ['chart', 'radiology', 'pathology', 'medical_photograph', 'ophthalmic_imaging', 'endoscopy', 'electrography']),
        ('image_subtype', ['chart',
                           'ct', 'mri', 'x_ray', 'pet', 'spect', 'scintigraphy', 'ultrasound', 'tractography',
                           'acid_fast', 'alcian_blue', 'congo_red', 'fish', 'giemsa', 'gram', 'h&e', 'immunostaining', 'masson_trichrome', 'methenamine_silver', 'methylene_blue', 'papanicolaou', 'pas', 'van_gieson',
                           'skin_photograph', 'oral_photograph', 'other_medical_photograph',
                           'b_scan', 'autofluorescence', 'fundus_photograph', 'gonioscopy', 'oct', 'ophthalmic_angiography', 'slit_lamp_photograph',
                           'gi_endoscopy', 'airway_endoscopy', 'other_endoscopy', 'arthroscopy',
                           'eeg', 'emg', 'ekg']),
        ('radiology_region', ['abdomen', 'breast', 'head', 'neck', 'pelvis', 'thorax',
                              'lower_limb', 'upper_limb', 'whole_body']),
        ('radiology_region_granular', ['abdomen', 'breast', 'head', 'neck', 'pelvis', 'thorax',
                                       'ankle', 'foot', 'hip', 'knee', 'lower_leg', 'thigh',
                                       'elbow', 'forearm', 'hand', 'shoulder', 'upper_arm', 'wrist',
                                       'whole_body']),
        ('radiology_view', ['axial', 'frontal', 'sagittal', 'oblique',
                            'occlusal', 'panoramic', 'periapical', 'intravascular', 'ultrasound_view']),
    )

    def __init__(self, image_folder, models_root='MultiCaReClassifier/models', save_path='', add_multiclass_columns=False):
        '''Set up the classifier and immediately run the predictions for all the images.

        image_folder (str): folder containing all the input images.
        models_root (str): folder containing the image classification models (one subfolder per
            model, named after the model with ':' replaced by '_'). If this folder does not
            exist but a local 'models' folder does, the latter is used instead.
        save_path (str): path to save the inference table. Nothing is saved if it is empty.
        add_multiclass_columns (bool): if True, multiclass columns will be added to the dataframe
            based on the multilabel column ('label_list').

        Raises FileNotFoundError if no model folder is found.'''

        # Joining with '' guarantees a trailing path separator.
        self.image_folder = os.path.join(image_folder, '')
        self.models_root = models_root
        self.save_path = save_path
        self.add_multiclass_columns = add_multiclass_columns

        # Per-instance copy, so that changes to it do not leak into other instances.
        self.label_dict = {name: list(classes) for name, classes in self._DEFAULT_LABEL_DICT.items()}

        # Only the models that are actually present on disk get a prediction column.
        models_dir = self._resolve_models_root()
        available_models = [name for name in self.label_dict
                            if os.path.isdir(os.path.join(models_dir, name.replace(':', '_')))]
        if not available_models:
            raise FileNotFoundError("No image classification models were found in '{}'.".format(models_dir))

        self.image_paths = get_image_files(self.image_folder)
        self.data = pd.DataFrame(columns=available_models)
        self.data['image_path'] = self.image_paths
        self.predict_image_classes()

    def predict_image_classes(self):
        '''Method used to get the predictions for each image.

        Models are applied from the top of the hierarchy downwards, because each model is
        only run on the images selected by the prediction of its upper model. The predictions
        are then postprocessed and, if `save_path` is set, saved as a csv file.'''

        # The depth of a model is the number of name parts separated by ':' or '.'.
        # `sorted` is stable, so models with the same depth keep the order of `label_dict`.
        models_by_depth = sorted(self.label_dict, key=lambda name: len(re.split(r'[:.]', name)))
        for model_name in models_by_depth:
            self._add_predictions(model_name)

        self.apply_postprocessing()
        if self.save_path:
            self.data.to_csv(self.save_path, index=False)

    def apply_postprocessing(self):
        '''Method used to postprocess the predictions.

        All the prediction columns (the ones starting with 'image_type') are merged into a
        single 'label_list' column and dropped. Labels are renamed, compound classes are
        added, multiclass columns are optionally generated, and finally the auxiliary classes
        are removed from 'label_list'.'''

        prediction_columns = [c for c in self.data.columns if c.startswith('image_type')]

        label_lists = []
        for row in self.data[prediction_columns].values.tolist():
            # Missing predictions are NaN (not strings), so they are filtered out here.
            labels = [self._LABEL_REPLACEMENTS.get(item, item) for item in row if isinstance(item, str)]
            label_lists.append(self._add_compound_classes(labels))

        self.data['label_list'] = pd.Series(label_lists, index=self.data.index, dtype=object)
        self.data.drop(columns=prediction_columns, inplace=True)

        # The multiclass columns need some auxiliary classes (e.g. 'ultrasound_view'),
        # so they have to be generated before those classes are removed.
        if self.add_multiclass_columns:
            self._generate_multiclass_columns()

        self.data['label_list'] = self.data['label_list'].apply(
            lambda labels: [item for item in labels if item not in self._AUXILIARY_CLASSES])

    def _resolve_models_root(self):
        '''Method used to get the folder the models are read from.

        Returns `models_root`, unless it is not an existing folder and the legacy 'models'
        folder is, in which case the legacy folder is returned.'''

        if os.path.isdir(self.models_root) or not os.path.isdir(self._LEGACY_MODELS_ROOT):
            return self.models_root
        return self._LEGACY_MODELS_ROOT

    def _identify_upper_model(self, model_name):
        '''Method used to identify the corresponding upper model of a given model.

        The upper model name is everything before the last ':' or '.' of `model_name`.
        Returns None if the name contains neither of them (top-level model).'''

        index = max(self._search_last_match(model_name, ':'), self._search_last_match(model_name, '.'))
        if index == -1:
            return None
        return model_name[:index]

    def _search_last_match(self, string, character):
        '''Method used to find the last mention of a character in a string.

        Returns the index of the last occurrence, or -1 if the character is not present.'''

        return string.rfind(character)

    def _add_predictions(self, model_name):
        '''Method used to add all the predictions of a given model to the outcome dataframe.

        The model is applied to the images that have no prediction for it yet and, if it has
        an upper model, for which the upper model predicted the class this model refines.
        Nothing is done if the model (or its upper model) has no column in the dataframe,
        i.e. if it was not found on disk.'''

        if model_name not in self.data.columns:
            return

        pending = self.data[model_name].isnull()
        upper_model = self._identify_upper_model(model_name)
        if upper_model is not None:
            if upper_model not in self.data.columns:
                return
            # E.g. 'image_type:radiology~main' -> 'radiology', '...:axial_region.thorax' -> 'thorax'.
            condition_class = model_name.split(':')[-1].split('~')[0].split('.')[-1]
            pending = pending & (self.data[upper_model] == condition_class)

        imgs = self.data.loc[pending, 'image_path'].values
        if len(imgs) == 0:
            return

        labels = np.array(self.label_dict[model_name])
        device = 'cpu'

        # Checkpoint path relative to the models folder; `learn.load` adds the file extension.
        checkpoint_file = os.path.join(model_name.replace(':', '_'), 'model')
        # The dataloaders are only needed to build the learner, so every image gets a dummy label.
        dls = ImageDataLoaders.from_path_func('', imgs, lambda x: '0', item_tfms=Resize((224, 224), method='squish'))
        learn = vision_learner(dls, resnet50, n_out=len(labels)).to_fp16()
        # Make `learn.load` look for the checkpoint in the configured models folder.
        learn.model_dir = self._resolve_models_root()
        learn.load(checkpoint_file, device=device, weights_only=False)
        test_dl = learn.dls.test_dl(imgs, device=device)
        probs, _ = learn.get_preds(dl=test_dl)

        # Convert to numpy explicitly instead of indexing a numpy array with a tensor.
        predicted_indexes = probs.argmax(dim=1).cpu().numpy()
        self.data.loc[pending, model_name] = labels[predicted_indexes]

    def _add_compound_classes(self, input_class_list):
        '''This method is used to add compound classes to the label list if the corresponding
        component classes are present.

        The input list is modified in place and also returned.'''

        for dct in self._COMPOUND_CLASSES:
            has_all_components = all(cls in input_class_list for cls in dct['components'])
            if has_all_components and dct['compound_class'] not in input_class_list:
                input_class_list.append(dct['compound_class'])
        return input_class_list

    def _generate_multiclass_columns(self):
        '''Method used to generate the multiclass columns based on the label list.

        Adds the columns 'image_type', 'image_subtype', 'radiology_region',
        'radiology_region_granular' and 'radiology_view' to the dataframe.'''

        for column_name, candidates in self._MULTICLASS_COLUMNS:
            self.data[column_name] = self.data['label_list'].apply(
                lambda labels, candidates=candidates: self._get_column_label(labels, candidates))

    def _get_column_label(self, column_list, label_list):
        '''Method used to get the label from a relevant list that is present in the predictions
        of a given image.

        column_list (list): labels predicted for the image (despite the parameter name).
        label_list (list): candidate labels of the multiclass column.

        Returns the last item of `column_list` that is included in `label_list`, or an empty
        string if there is none.'''

        return next((item for item in reversed(column_list) if item in label_list), '')