web loader support proxy

This commit is contained in:
Yimi81
2025-02-14 07:15:09 +00:00
parent 304aed0f13
commit d3f71930f0
4 changed files with 58 additions and 2 deletions

View File

@@ -1,7 +1,9 @@
import socket
import aiohttp
import asyncio
import urllib.parse
import validators
from typing import Union, Sequence, Iterator
from typing import Union, Sequence, Iterator, Dict
from langchain_community.document_loaders import (
WebBaseLoader,
@@ -68,6 +70,45 @@ def resolve_hostname(hostname):
class SafeWebBaseLoader(WebBaseLoader):
"""WebBaseLoader with enhanced error handling for URLs."""
def __init__(self, trust_env: bool = False, *args, **kwargs):
"""Initialize SafeWebBaseLoader
Args:
trust_env (bool, optional): set to True if using proxy to make web requests, for example
using http(s)_proxy environment variables. Defaults to False.
"""
super().__init__(*args, **kwargs)
self.trust_env = trust_env
async def _fetch(
self, url: str, retries: int = 3, cooldown: int = 2, backoff: float = 1.5
) -> str:
async with aiohttp.ClientSession(trust_env=self.trust_env) as session:
for i in range(retries):
try:
kwargs: Dict = dict(
headers=self.session.headers,
cookies=self.session.cookies.get_dict(),
)
if not self.session.verify:
kwargs["ssl"] = False
async with session.get(
url, **(self.requests_kwargs | kwargs)
) as response:
if self.raise_for_status:
response.raise_for_status()
return await response.text()
except aiohttp.ClientConnectionError as e:
if i == retries - 1:
raise
else:
log.warning(
f"Error fetching {url} with attempt "
f"{i + 1}/{retries}: {e}. Retrying..."
)
await asyncio.sleep(cooldown * backoff**i)
raise ValueError("retry count exceeded")
def lazy_load(self) -> Iterator[Document]:
"""Lazy load text from the url(s) in web_path with error handling."""
for path in self.web_paths:
@@ -96,13 +137,15 @@ def get_web_loader(
urls: Union[str, Sequence[str]],
verify_ssl: bool = True,
requests_per_second: int = 2,
trust_env: bool = False,
):
# Check if the URLs are valid
safe_urls = safe_validate_urls([urls] if isinstance(urls, str) else urls)
return SafeWebBaseLoader(
safe_urls,
web_path=safe_urls,
verify_ssl=verify_ssl,
requests_per_second=requests_per_second,
continue_on_failure=True,
trust_env=trust_env
)