Al fin entendí para qué se usaba el Factory Pattern

Tuve que cambiar frecuentemente entre varias variantes de un modelo, cada una de ellas creada con distintos argumentos. Añadir variantes nuevas se volvió insostenible, pero el factory pattern mejoró todo.

Estoy trabajando en reconstruir un vídeo usando una representación neuronal implícita, que es simplemente un fully connected multilayer perceptron que recibe las coordenadas $x,y,t$ de un píxel y entrega la intensidad (o el color) del vídeo en esa coordenada. Hacer esto de forma directa entrega vídeos muy borrosos y sin detalles, por lo que se suele añadir un preprocesamiento a las coordenadas llamado Fourier Features (ver la página del proyecto para más detalles).

Puesto que Fourier Features originalmente está pensado para imágenes (2D), y yo quiero usar vídeos (3D), estuvimos discutiendo varias posibles opciones de extenderlo, así que tuve que probarlas y cambiar frecuentemente entre una y otra. Aunque me sentí muy agradecida por haber puesto un enchufe entre mi lámpara y los cables, esto no fue suficiente ya que cada nueva variante requería hacer cambios en varias partes del código, lo que indicaba que no había seguido la idea de “mantener en un mismo lugar las cosas que cambian por las mismas razones”.

Varias modelos que siguen la misma interfaz pero se inicializan de formas distintas

El tener que probar distintas versiones de Fourier Features, todas con un funcionamiento muy similar salvo una capa de preprocesamiento, me llevó a crear la clase abstracta FFnet_2Dspace_time, que resolvió una parte del problema permitiendo hacer dependency inversion, de forma que el entrenamiento no dependiera de los detalles de implementación específicos de cada red. Esto ha sido especialemente útil últimamente porque he tenido que probar distintas variantes de preprocesamiento. Los distintos modelos se ven así:

# fourier_features.py
from abc import ABC, abstractmethod 

class FFnet_2Dspace_time(ABC):
  ... 

class FFx_evenpt_net(FFnet_2Dspace_time):
  def __init__(self, mapping_size, sigma, m, gtimg, key): 
    ...
  ...

class FF_fraction_static_mixed_net(FFnet_2Dspace_time):
  def __init__(self, sigma, static_fraction, desired_ffvector_len, gtimg, key): 
    ...
  ...  

Es decir, tengo varias clases, cada una con su propia inicialización, que satisfacen la misma interfaz.

Muchos cambios requeridos para agregar un modelo nuevo

Actualmente mi script de entrenamiento recibe argumentos desde un argparser, lo que se ve así:


from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter

parser = ArgumentParser(description="Train from data stored in npy file", formatter_class=ArgumentDefaultsHelpFormatter)

parser.add_argument("--chosen_patient", help="Paciente elegido (P01, P02, ...etc). Ver disponibles en HARVARD_DB_IDs, harvardDB_processing.py", type=str, default="P01")
...

Esto permite llamar al script con distintos parámetros desde un archivo .sh:

python training.py --main_folder "results" --chosen_patient "P02" --nIter 20001 --batch_size 2 --tvval 0 --hermitic_fill --mapping_size 200 --sigma 4.5 --sub_spokes_per_frame 16 --laxmap  

El problema es que no tengo una buena forma de cambiar la subclase de FFnet_2Dspace_time a usar (la llamaré FFnet) desde el archivo .sh. Cada vez que quiero agregar una nueva FFnet debo hacer los siguientes cambios:

  • Añadir la subclase al archivo fourier_features.py. Eso no es problema, ya que esa parte del código es relativamente fácil de extender.
  • Cambiar el parser, de forma que permita seleccionar la nueva FFnet. También debe poder recibir los argumentos (que pueden ser distintos de los de las otras redes).
  • Cambiar la parte de la preparación del entrenamiento donde creo la FFnet. Debo leer el parser, obtener los argumentos y crear el objeto. Esto incluye un montón de ifs y es, desde luego, un desastre a la hora de hacer cambios 😩. Ver código:
if args_dic["mixed"]: 
    FFnet = FFx_mixedff_space_time_net(...) 
elif args_dic["static_mixed"]: 
    FFnet = FF_static_mixed_net(...) 
elif args_dic["p_static_mixed"] is not None: 
    FFnet = FF_fraction_static_mixed_net(...)

params = FFnet.init_params(...)
  • Cambiar la clase que se usa para hacer las reconstrucciones a partir de los resultados guardados en el entrenamiento. Esta clase necesita cargar la FFnet, y tiene un código similar al de arriba.

Así que, con mi código actual, necesito hacer “una cirugía de pecho para cambiarme el abrigo”, una violación del Principio Abierto-Cerrado (Open-Closed) de los Principios SOLID.

OCP

Un meme clásico para explicar el Open-Closed Principle.

Factory Pattern al rescate

Creo que mi problema suele resolverse con el factory pattern. La idea es simple: separar la creación de un objeto de su uso. Puesto que a mi entrenamiento le es indiferente la FFnet específica que uso, no hay motivo para que la creación de la FFnet esté en training.py.

Mi solución fue:

  • Añadir a cada clase FFnet un objeto con la información necesaria para crear un parser. Esto incluye una palabra corta (por ejemplo, frac-static-mixed para referirse a la red FF_fraction_static_mixed_net), la cantidad de argumentos que requiere y una ayuda (“el primero es un float que representa….blah…blah…”). Este debe estar asociado a la clase, no a una instancia específica.
  • Crear automáticamente un parser a partir de la información de las clases disponibles.
  • Crear una función que reciba el diccionario parseado y entregue la FFnet adecuada (con la instancia correcta). Esta es la fábrica o factory.
  • Reemplazar la creación de las FFnets con esa función.

Un código de ejemplo que prueba la idea:

from dataclasses import dataclass 

@dataclass 
class RazaParser(): 
    shortcut: str 
    nargs: int 
    help: str 


def find_correct_class(subclass_map, adict): 
    for shortcutkey, theclass in subclass_map.items(): 
        if adict[shortcutkey] is not None: 
            return shortcutkey, theclass 
 

from abc import ABC , abstractmethod

class Perro(ABC):

    @abstractmethod
    def get_parser() -> RazaParser: 
        pass 

    @abstractmethod
    def ladrido() -> None: 
        pass 

class Chihuahua(Perro): 
    def __init__(self, color_ropa): 
        print(f"Soy gruñon y me visten de color {color_ropa}")

    def get_parser() -> RazaParser:
        return RazaParser('chihuahua', 1, "color ropa: str. Ejemplo: morado.")   

    def ladrido(self):
        print("guau")

class PastorAleman(Perro): 
    def __init__(self, is_police:bool, cuteness:str):
        print(f"Yo {'no ' if not is_police else ''}soy un perro policía {cuteness}") 

    def get_parser() -> RazaParser:
        return RazaParser('pastor', 2, "- bool: si true, es un perro policia. - cuteness: str, gruñon, muy tierno, etc.")
    
    def ladrido(self):
        print("barf")

def create_perro_from_parser(adict): 
    subclass_map = {subclass.get_parser().shortcut: subclass for subclass in Perro.__subclasses__()}
    shortcut, subclass = find_correct_class(subclass_map, adict)
    paramslist = adict[shortcut]
    instance = super(Perro, subclass).__new__(subclass)
    instance.__init__(*paramslist) 
    return instance

Con este código, puedo entregar un diccionario de argumentos de la forma

args1 = {'pastor': [True, 'muy lindo'], 'chihuahua': None}
args2 = {'pastor': None, 'chihuahua': ['azul'], 'juguete': "pollo chillón"}

Notar como los keys asociados a una clase (pastor, chihuahua) tienen listas asocidas, o Nones; esto es porque pretendo usar mutual exclusion y nargs, algo así:

group = parser.add_mutually_exclusive_group()
group.add_argument("--perro",nargs=2)
group.add_argument("--chihuahua",nargs=1)

Notar que el diccionario puede incluir argumentos que no tienen nada que ver con las clases (juguete por ejemplo).

La función create_perro_from_parser puede crear e inicializar un objeto adecuado:

>>> perro = create_perro_from_parser(args1)
Yo soy un perro policía muy lindo
>>> perro.ladrido()
barf 
>>> perro = create_perro_from_parser(args2)
Soy gruñon y me visten de color azul
>>> perro.ladrido()
guau

Algunas consideraciones finales

Soy consciente de algunos problemas importantes con esta solución, que intentaré mitigar:

  • Va a introducir breaking changes en el código, específicamente en el procesamiento de los resultados, ya que va a cambiar la estructura de los diccionarios de argumentos. Puedo adaptar fácilmente el código para hacerlo compatible con los resultados actuales, pero tendré que mantener trozos antiguos de código para no perder la backward compatibility con los resultados antiguos.
  • Necesito que la Factory sea resiliente a los cambios en el diccionario de argumentos generados por la adición de una nueva Fnet. Sin esto, cada vez que añada un nueva FFnet, la factory no será capaz de trabajar con un diccionario de una versión anterior.