from transformers import Pipeline


class MultitaskTokenClassificationPipeline(Pipeline):
    def _sanitize_parameters(self, **kwargs):
        # Route the optional "text" keyword argument to preprocess;
        # _forward and postprocess take no extra parameters.
        preprocess_kwargs = {}
        if "text" in kwargs:
            preprocess_kwargs["text"] = kwargs["text"]
        return preprocess_kwargs, {}, {}

    def preprocess(self, text, **kwargs):
        # The wrapped floret model consumes raw strings, so no tokenization
        # is needed here.
        return text

    def _forward(self, text):
        # Run the underlying floret model and keep only the top-1 prediction.
        predictions, probabilities = self.model.get_floret_model().predict([text], k=1)
        return predictions, probabilities

    def postprocess(self, outputs, **kwargs):
        """
        Format the model outputs as a JSON-compatible dictionary.

        :param outputs: (predictions, probabilities) tuple returned by _forward
        :param kwargs: unused
        :return: dict with the predicted label and its confidence in percent
        """
        predictions, probabilities = outputs
        label = predictions[0][0].replace("__label__", "")  # strip floret's "__label__" prefix
        confidence = float(probabilities[0][0])  # plain float for JSON serialization
        return {"label": label, "confidence": round(confidence * 100, 2)}
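
# --- Usage sketch (not part of the original pipeline code) -------------------
# A minimal, hedged example of how this pipeline might be registered and
# called. The task name "multitask-token-classification" and the model class
# `FloretModel` (assumed to expose get_floret_model()) are illustrative
# assumptions, not definitions from the code above.
if __name__ == "__main__":
    from transformers.pipelines import PIPELINE_REGISTRY

    PIPELINE_REGISTRY.register_pipeline(
        "multitask-token-classification",
        pipeline_class=MultitaskTokenClassificationPipeline,
    )

    # Direct instantiation would also work, given a loaded floret-backed model:
    # model = FloretModel.from_pretrained("path/to/checkpoint")  # hypothetical
    # pipe = MultitaskTokenClassificationPipeline(model=model)
    # pipe("Some input sentence")  # -> {"label": ..., "confidence": ...}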