11import json
22import logging
33import os
4+ import time
45from collections .abc import Iterator
5- from dataclasses import dataclass
66from pathlib import Path
7+ from typing import Literal
78
89import pytest
10+ import urllib3
911import vectorize_client as v
12+ from vectorize_client import ApiClient
1013
1114from langchain_vectorize .retrievers import VectorizeRetriever
1215
1316
14- @dataclass
15- class TestContext :
16- api_client : v .ApiClient
17- api_token : str
18- org_id : str
19-
20-
@pytest.fixture(scope="session")
def api_token() -> str:
    """Return the API token taken from the ``VECTORIZE_TOKEN`` variable.

    Raises:
        ValueError: If ``VECTORIZE_TOKEN`` is unset or empty.
    """
    value = os.getenv("VECTORIZE_TOKEN")
    if value:
        return value
    msg = "Please set the VECTORIZE_TOKEN environment variable"
    raise ValueError(msg)
2824
@@ -31,21 +27,29 @@ def api_token() -> str:
def org_id() -> str:
    """Return the organization id taken from the ``VECTORIZE_ORG`` variable.

    Raises:
        ValueError: If ``VECTORIZE_ORG`` is unset or empty.
    """
    value = os.getenv("VECTORIZE_ORG")
    if value:
        return value
    msg = "Please set the VECTORIZE_ORG environment variable"
    raise ValueError(msg)
3733
3834
@pytest.fixture(scope="session")
def environment() -> Literal["prod", "dev", "local", "staging"]:
    """Return the target environment named by ``VECTORIZE_ENV``.

    Falls back to ``"prod"`` when the variable is not set.

    Raises:
        ValueError: If the value is not one of ``prod``, ``dev``, ``local``
            or ``staging``.
    """
    selected = os.getenv("VECTORIZE_ENV", "prod")
    if selected in ("prod", "dev", "local", "staging"):
        return selected
    msg = "Invalid VECTORIZE_ENV environment variable."
    raise ValueError(msg)
42+
43+
44+ @pytest .fixture (scope = "session" )
45+ def api_client (api_token : str , environment : str ) -> Iterator [ApiClient ]:
4246 header_name = None
4347 header_value = None
44- if env == "prod" :
48+ if environment == "prod" :
4549 host = "https://api.vectorize.io/v1"
46- elif env == "dev" :
50+ elif environment == "dev" :
4751 host = "https://api-dev.vectorize.io/v1"
48- elif env == "local" :
52+ elif environment == "local" :
4953 host = "http://localhost:3000/api"
5054 header_name = "x-lambda-api-key"
5155 header_value = api_token
@@ -87,8 +91,6 @@ def pipeline_id(api_client: v.ApiClient, org_id: str) -> Iterator[str]:
8791 ),
8892 )
8993
90- import urllib3
91-
9294 http = urllib3 .PoolManager ()
9395 this_dir = Path (__file__ ).parent
9496 file_path = this_dir / "research.pdf"
@@ -137,7 +139,9 @@ def pipeline_id(api_client: v.ApiClient, org_id: str) -> Iterator[str]:
137139 config = {},
138140 ),
139141 ai_platform = v .AIPlatformSchema (
140- id = builtin_ai_platform , type = v .AIPlatformType .VECTORIZE , config = v .AIPlatformConfigSchema ()
142+ id = builtin_ai_platform ,
143+ type = v .AIPlatformType .VECTORIZE ,
144+ config = v .AIPlatformConfigSchema (),
141145 ),
142146 pipeline_name = "Test pipeline" ,
143147 schedule = v .ScheduleSchema (type = v .ScheduleSchemaType .MANUAL ),
@@ -154,20 +158,48 @@ def pipeline_id(api_client: v.ApiClient, org_id: str) -> Iterator[str]:
154158 logging .exception ("Failed to delete pipeline %s" , pipeline_id )
155159
156160
def test_retrieve_init_args(
    environment: Literal["prod", "dev", "local", "staging"],
    api_token: str,
    org_id: str,
    pipeline_id: str,
) -> None:
    """Check retrieval when every option is passed to the constructor.

    The retriever is queried once per second until it returns exactly two
    documents, for at most 180 seconds.
    NOTE(review): presumably the polling is needed because the freshly
    created pipeline takes time to index its document - confirm.

    Raises:
        RuntimeError: If two documents are not returned within 180 seconds.
    """
    retriever = VectorizeRetriever(
        environment=environment,
        api_token=api_token,
        organization=org_id,
        pipeline_id=pipeline_id,
        num_results=2,
    )
    # Use the monotonic clock so a wall-clock adjustment during the wait
    # cannot shorten or extend the timeout.
    deadline = time.monotonic() + 180
    while True:
        docs = retriever.invoke(input="What are you?")
        if len(docs) == 2:
            break
        if time.monotonic() > deadline:
            msg = "Docs not retrieved in time"
            raise RuntimeError(msg)
        time.sleep(1)
183+
184+
def test_retrieve_invoke_args(
    environment: Literal["prod", "dev", "local", "staging"],
    api_token: str,
    org_id: str,
    pipeline_id: str,
) -> None:
    """Check retrieval when the query options are passed to ``invoke``.

    Only the environment and token go to the constructor; organization,
    pipeline id and result count are supplied on each ``invoke`` call. The
    retriever is queried once per second until it returns exactly two
    documents, for at most 180 seconds.
    NOTE(review): presumably the polling is needed because the freshly
    created pipeline takes time to index its document - confirm.

    Raises:
        RuntimeError: If two documents are not returned within 180 seconds.
    """
    retriever = VectorizeRetriever(environment=environment, api_token=api_token)
    # Use the monotonic clock so a wall-clock adjustment during the wait
    # cannot shorten or extend the timeout.
    deadline = time.monotonic() + 180
    while True:
        docs = retriever.invoke(
            input="What are you?",
            organization=org_id,
            pipeline_id=pipeline_id,
            num_results=2,
        )
        if len(docs) == 2:
            break
        if time.monotonic() > deadline:
            msg = "Docs not retrieved in time"
            raise RuntimeError(msg)
        time.sleep(1)
0 commit comments