1 change: 1 addition & 0 deletions docs/examples/conftest.py
@@ -12,6 +12,7 @@
"simple_rag_with_filter.py",
"mcp_example.py",
"client.py",
"pii_serve.py",
}


4 changes: 2 additions & 2 deletions mellea/backends/__init__.py
@@ -34,7 +34,7 @@ def __init__(
self.model_options = model_options if model_options is not None else {}

@abc.abstractmethod
- def generate_from_context(
+ async def generate_from_context(
self,
action: Component | CBlock,
ctx: Context,
@@ -58,7 +58,7 @@ def generate_from_context(
...

@abc.abstractmethod
- def generate_from_raw(
+ async def generate_from_raw(
self,
actions: list[Component | CBlock],
ctx: Context,
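A minimal sketch of what the new async contract on Backend looks like from an implementer's and caller's perspective. The DemoBackend class and the simplified signatures below are illustrative stand-ins, not part of this PR; the real abstract methods take Component | CBlock actions, a Context, and additional keyword arguments.

import asyncio


class Backend:
    # Simplified stand-in for mellea.backends.Backend's new async abstract method.
    async def generate_from_context(self, action: str, ctx: list[str]) -> str:
        raise NotImplementedError


class DemoBackend(Backend):
    # Subclasses now define coroutines and can await other async work directly.
    async def generate_from_context(self, action: str, ctx: list[str]) -> str:
        await asyncio.sleep(0)  # stand-in for an async model call
        return f"echo: {action}"


async def main() -> None:
    backend = DemoBackend()
    # Callers await the backend method instead of invoking it synchronously.
    print(await backend.generate_from_context("hello", []))


asyncio.run(main())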
2 changes: 1 addition & 1 deletion mellea/backends/dummy.py
@@ -16,7 +16,7 @@ def __init__(self, responses: list[str] | None):
self.responses = responses
self.idx = 0

- def generate_from_context(
+ async def generate_from_context(
self,
action: Component | CBlock,
ctx: Context,
22 changes: 13 additions & 9 deletions mellea/backends/huggingface.py
@@ -182,7 +182,7 @@ def __init__(
self._added_adapters: dict[str, LocalHFAdapter] = {}
self._loaded_adapters: dict[str, LocalHFAdapter] = {}

- def generate_from_context(
+ async def generate_from_context(
self,
action: Component | CBlock,
ctx: Context,
@@ -229,21 +229,23 @@ def generate_from_context(

if reroute_to_alora:
# Keep the alora requirement handling separate for now.
- mot = self._generate_from_intrinsic(
+ mot = await self._generate_from_intrinsic(
alora_action, ctx, model_options=model_opts
)
return mot, ctx.add(alora_action).add(mot)

elif isinstance(action, Intrinsic):
- mot = self._generate_from_intrinsic(action, ctx, model_options=model_opts)
+ mot = await self._generate_from_intrinsic(
+ action, ctx, model_options=model_opts
+ )
return mot, ctx.add(action).add(mot)

- mot = self._generate_from_context_standard(
+ mot = await self._generate_from_context_standard(
action, ctx, _format=format, model_options=model_opts, tool_calls=tool_calls
)
return mot, ctx.add(action).add(mot)

- def _generate_from_intrinsic(
+ async def _generate_from_intrinsic(
self, action: Intrinsic, ctx: Context, *, model_options: dict[str, Any]
) -> ModelOutputThunk:
if not ctx.is_chat_context:
@@ -394,7 +396,7 @@ async def granite_common_processing(

return output

- def _generate_from_context_standard(
+ async def _generate_from_context_standard(
self,
action: Component | CBlock,
ctx: Context,
@@ -627,7 +629,7 @@ async def post_processing(

mot._generate_log = generate_log

- def generate_from_raw(
+ async def generate_from_raw(
self,
actions: list[Component | CBlock],
ctx: Context,
@@ -663,7 +665,7 @@ def generate_from_raw(
)

if format is None:
- outputs = self._model.generate( # type: ignore
+ outputs = await asyncio.to_thread(
+ self._model.generate, # type: ignore
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
return_dict_in_generate=True,
Expand All @@ -681,7 +684,8 @@ def generate_from_raw(
from outlines.processors import RegexLogitsProcessor
from transformers import LogitsProcessorList

- outputs = self._model.generate( # type: ignore
+ outputs = await asyncio.to_thread(
+ self._model.generate, # type: ignore
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
return_dict_in_generate=True,
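A short, self-contained sketch of the asyncio.to_thread pattern adopted above for the blocking generate() call; FakeModel and its generate method are placeholders rather than the real transformers API.

import asyncio
import time


class FakeModel:
    # Stand-in for a blocking, compute-bound generate() like the HF model exposes.
    def generate(self, prompt: str) -> str:
        time.sleep(0.1)  # simulate blocking work
        return prompt.upper()


async def main() -> None:
    model = FakeModel()
    # asyncio.to_thread runs the blocking call in a worker thread so the event
    # loop (and any other in-flight generations) can keep making progress.
    result = await asyncio.to_thread(model.generate, "hello world")
    print(result)


asyncio.run(main())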
78 changes: 72 additions & 6 deletions mellea/backends/litellm.py
@@ -110,7 +110,7 @@ def __init__(

self._past_event_loops: set[int] = set()

- def generate_from_context(
+ async def generate_from_context(
self,
action: Component | CBlock,
ctx: Context,
@@ -123,7 +123,7 @@ def generate_from_context(
assert ctx.is_chat_context, NotImplementedError(
"The Openai backend only supports chat-like contexts."
)
- mot = self._generate_from_chat_context_standard(
+ mot = await self._generate_from_chat_context_standard(
action,
ctx,
_format=format,
@@ -231,7 +231,7 @@ def _make_backend_specific_and_remove(

return backend_specific

- def _generate_from_chat_context_standard(
+ async def _generate_from_chat_context_standard(
self,
action: Component | CBlock,
ctx: Context,
@@ -448,7 +448,7 @@ async def post_processing(
"format": _format,
"tools_available": tools,
"tools_called": mot.tool_calls,
"seed": thinking,
"thinking": thinking,
}
generate_log.action = mot._action
generate_log.result = mot
@@ -474,7 +474,7 @@ def _extract_tools(
FancyLogger.get_logger().info(f"Tools for call: {tools.keys()}")
return tools

- def generate_from_raw(
+ async def generate_from_raw(
self,
actions: list[Component | CBlock],
ctx: Context,
@@ -484,7 +484,73 @@ def generate_from_raw(
tool_calls: bool = False,
) -> list[ModelOutputThunk]:
"""Generate using the completions api. Gives the input provided to the model without templating."""
- raise NotImplementedError("This method is not implemented yet.")
+ extra_body = {}
+ if format is not None:
+ FancyLogger.get_logger().warning(
+ "The official OpenAI completion api does not accept response format / structured decoding; "
+ "it will be passed as an extra arg."
+ )
+
+ # Some versions (like vllm's version) of the OpenAI API support structured decoding for completions requests.
+ extra_body["guided_json"] = format.model_json_schema()
+ if tool_calls:
+ FancyLogger.get_logger().warning(
+ "The completion endpoint does not support tool calling."
+ )
+
+ # We don't do anything fancy for model_opts with generate from raw; litellm has too many potential options depending on provider.
+ model_opts = self._simplify_and_merge(model_options)
+ model_specific_options = self._make_backend_specific_and_remove(model_opts)
+
+ if self._has_potential_event_loop_errors():
+ FancyLogger().get_logger().warning(
+ "There is a known bug with litellm. This generation call may fail. If it does, you should ensure that you are either running only synchronous Mellea functions or running async Mellea functions from one asyncio.run() call."
+ )
+
+ prompts = [self.formatter.print(action) for action in actions]
+
+ completion_response = await litellm.atext_completion(
+ model=self._model_id, prompt=prompts, **model_specific_options
+ )
+
+ # Necessary for type checker.
+ assert isinstance(completion_response, litellm.TextCompletionResponse) # type: ignore
+
+ results = []
+ date = datetime.datetime.now()
+ responses = completion_response.choices
+ if len(responses) != len(prompts):
+ FancyLogger().get_logger().error(
+ "litellm appears to have sent your batch request as a single message; this typically happens with providers like ollama that don't support batching"
+ )
+
+ for res, action, prompt in zip(responses, actions, prompts):
+ output = ModelOutputThunk(res.text) # type: ignore
+ output._context = None # There is no context for generate_from_raw for now
+ output._action = action
+ output._model_options = model_opts
+ output._meta = {
+ "litellm_chat_response": res.model_dump(),
+ "usage": completion_response.usage.model_dump()
+ if completion_response.usage
+ else None,
+ }
+
+ self.formatter.parse(action, output)
+
+ generate_log = GenerateLog()
+ generate_log.prompt = prompt
+ generate_log.backend = f"litellm::{self.model_id!s}"
+ generate_log.model_options = model_opts
+ generate_log.date = date
+ generate_log.model_output = completion_response
+ generate_log.extra = {"seed": model_opts.get("seed", None)}
+ generate_log.action = action
+ output._generate_log = generate_log
+
+ results.append(output)
+
+ return results

def _extract_model_tool_requests(
self,
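A small standalone illustration of the prompt/choice pairing and length check used in the new generate_from_raw; FakeChoice and the hard-coded data are placeholders for litellm's TextCompletionResponse choices.

import logging
from dataclasses import dataclass

logging.basicConfig(level=logging.INFO)


@dataclass
class FakeChoice:
    text: str


def pair_prompts_with_choices(
    prompts: list[str], choices: list[FakeChoice]
) -> list[tuple[str, str]]:
    # Mirrors the check above: some providers collapse a batched request into a
    # single response, so the choice count may not match the prompt count.
    if len(choices) != len(prompts):
        logging.getLogger(__name__).error(
            "got %d choices for %d prompts; the provider may not support batching",
            len(choices),
            len(prompts),
        )
    # zip() truncates to the shorter list, just like the loop in the diff.
    return [(prompt, choice.text) for prompt, choice in zip(prompts, choices)]


print(pair_prompts_with_choices(["a", "b"], [FakeChoice("A"), FakeChoice("B")]))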
43 changes: 19 additions & 24 deletions mellea/backends/ollama.py
@@ -252,7 +252,7 @@ def _make_backend_specific_and_remove(
)
return ModelOption.remove_special_keys(backend_specific)

- def generate_from_context(
+ async def generate_from_context(
self,
action: Component | CBlock,
ctx: Context,
@@ -265,7 +265,7 @@ def generate_from_context(
assert ctx.is_chat_context, (
"The ollama backend only supports chat-like contexts."
)
- mot = self.generate_from_chat_context(
+ mot = await self.generate_from_chat_context(
action,
ctx,
_format=format,
Expand All @@ -275,7 +275,7 @@ def generate_from_context(

return mot, ctx.add(action).add(mot)

- def generate_from_chat_context(
+ async def generate_from_chat_context(
self,
action: Component | CBlock,
ctx: Context,
@@ -375,6 +375,8 @@ def generate_from_chat_context(

# This function should always be called from a running event loop so we don't have to worry about
# scheduling the task to a specific event loop here.

+ # Use `create_task` so that we don't have to specifically await this task before it starts executing.
output._generate = asyncio.create_task(
send_to_queue(chat_response, output._async_queue)
)
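A rough sketch of the create_task-plus-queue streaming pattern the comment above describes; send_to_queue and the chunk data here are simplified stand-ins for Mellea's internals.

import asyncio


async def send_to_queue(chunks: list[str], queue: asyncio.Queue) -> None:
    # Producer: push streamed chunks, then a sentinel to signal completion.
    for chunk in chunks:
        await queue.put(chunk)
    await queue.put(None)


async def main() -> None:
    queue: asyncio.Queue = asyncio.Queue()
    # create_task schedules the producer right away; the caller does not have to
    # await it before a consumer starts draining the queue.
    task = asyncio.create_task(send_to_queue(["a", "b", "c"], queue))
    while (item := await queue.get()) is not None:
        print(item)
    await task  # surface any exception raised by the producer


asyncio.run(main())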
@@ -385,7 +387,7 @@

return output

- def generate_from_raw(
+ async def generate_from_raw(
self,
actions: list[Component | CBlock],
ctx: Context,
@@ -410,27 +412,20 @@ def generate_from_raw(
# See https://github.com/ollama/ollama/blob/main/docs/faq.md#how-does-ollama-handle-concurrent-requests.
prompts = [self.formatter.print(action) for action in actions]

- async def get_response():
- # Run async so that we can make use of Ollama's concurrency.
- coroutines: list[Coroutine[Any, Any, ollama.GenerateResponse]] = []
- for prompt in prompts:
- co = self._async_client.generate(
- model=self._get_ollama_model_id(),
- prompt=prompt,
- raw=True,
- think=model_opts.get(ModelOption.THINKING, None),
- format=format.model_json_schema() if format is not None else None,
- options=self._make_backend_specific_and_remove(model_opts),
- )
- coroutines.append(co)
-
- responses = await asyncio.gather(*coroutines, return_exceptions=True)
- return responses
+ # Run async so that we can make use of Ollama's concurrency.
+ coroutines: list[Coroutine[Any, Any, ollama.GenerateResponse]] = []
+ for prompt in prompts:
+ co = self._async_client.generate(
+ model=self._get_ollama_model_id(),
+ prompt=prompt,
+ raw=True,
+ think=model_opts.get(ModelOption.THINKING, None),
+ format=format.model_json_schema() if format is not None else None,
+ options=self._make_backend_specific_and_remove(model_opts),
+ )
+ coroutines.append(co)

- # Run in the same event_loop like other Mellea async code called from a sync function.
- responses: list[ollama.GenerateResponse | BaseException] = _run_async_in_thread(
- get_response()
- )
+ responses = await asyncio.gather(*coroutines, return_exceptions=True)

results = []
date = datetime.datetime.now()
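A minimal illustration of the gather(..., return_exceptions=True) fan-out used above; fetch is a stand-in for the per-prompt self._async_client.generate call.

import asyncio


async def fetch(prompt: str) -> str:
    # Stand-in for the per-prompt Ollama request; one prompt fails on purpose to
    # show how return_exceptions=True preserves the other results.
    if prompt == "bad":
        raise ValueError("simulated provider error")
    await asyncio.sleep(0)
    return prompt.upper()


async def main() -> None:
    prompts = ["hello", "bad", "world"]
    results = await asyncio.gather(*(fetch(p) for p in prompts), return_exceptions=True)
    for prompt, result in zip(prompts, results):
        if isinstance(result, BaseException):
            print(f"{prompt!r} failed: {result}")
        else:
            print(f"{prompt!r} -> {result}")


asyncio.run(main())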