diff --git a/mellea/stdlib/requirement.py b/mellea/stdlib/requirement.py index 1aa849fa..e79aa915 100644 --- a/mellea/stdlib/requirement.py +++ b/mellea/stdlib/requirement.py @@ -156,6 +156,13 @@ def format_for_llm(self) -> TemplateRepresentation | str: template_order=["*", "Requirement"], ) + def __and__(self, other): + return ConjunctiveRequirement([self,other]) + def __or__(self, other): + return DisjunctiveRequirement([self,other]) + def __not__(self): + return NegativeRequirement(self) + class LLMaJRequirement(Requirement): """A requirement that always uses LLM-as-a-Judge. Any available constraint ALoRA will be ignored.""" @@ -178,6 +185,98 @@ def __init__(self, description: str, alora: Alora | None = None): self.alora = alora +class ConjunctiveRequirement(Requirement): + def __init__(self, requirements: list[Requirement],): + self.requirements = requirements + + @property + def description(self): + return "\n* ".join( + ["Satisfy all of these requirements:"] + \ + [r.description for r in self.requirements]) + + def validate(self, *args, **kwargs): + results = [r.validate(*args, **kwargs) for r in self.requirements] + return ValidationResult( + result = all(results), + reason = "\n* ".join( + ["These requirements are not satisfied:"]+ + [r.reason for r in results if not r]), + score = max([r.score for r in results if not r])) + + def __and__(self, other): + match other: + case ConjunctiveRequirement(): + ConjunctiveRequirement(self.requirements+other.requirements) + case Requirement(): + ConjunctiveRequirement(self.requirements+[other]) + + def __or__(self, other): + return DisjunctiveRequirement([self,other]) + def __not__(self): + return NegativeRequirement(self) + + + +class DisjunctiveRequirement(Requirement): + def __init__(self, requirements: list[Requirement],): + self.requirements = requirements + + @property + def description(self): + return "\n* ".join( + ["Satisfy at least one of these requirements:"] + \ + [r.description for r in self.requirements]) + + def validate(self, *args, **kwargs): + results = [r.validate(*args, **kwargs) for r in self.requirements] + return ValidationResult( + result = any(results), + reason = "\n* ".join( + ["These requirements are satisfied:"]+ + [r.reason for r in results if r]), + score = min([r.score for r in results if not r])) + + def __and__(self, other): + return ConjunctiveRequirement([self,other]) + def __or__(self, other): + match other: + case DisjunctiveRequirement(): + DisjunctiveRequirement(self.requirements+other.requirements) + case Requirement(): + DisjunctiveRequirement(self.requirements+[other]) + def __not__(self): + return NegativeRequirement(self) + + +class NegativeRequirement(Requirement): + def __init__(self, requirement: Requirement,): + self.requirement = requirement + + @property + def description(self): + return f"Do not satisfy this requirement: {self.requirement.description}" + + def __getattr__(self, name): + # delegate lookup to self.requirement + return getattr(self.requirement, name) + + def validate(self, *args, **kwargs): + result = self.requirement.validate(*args, **kwargs) + return ValidationResult( + result = not result, + reason = result.reason, + # score = ??? + ) + + def __and__(self, other): + return ConjunctiveRequirement([self,other]) + def __or__(self, other): + return DisjunctiveRequirement([self,other]) + def __not__(self): + return self.requirement + + def reqify(r: str | Requirement) -> Requirement: """Maps strings to Requirements.