变更
This commit is contained in:
@@ -0,0 +1,69 @@
|
||||
from unittest import mock
|
||||
|
||||
from scrapy.settings import Settings
|
||||
|
||||
from scrapy_redis import defaults
|
||||
from scrapy_redis.connection import from_settings, get_redis, get_redis_from_settings
|
||||
|
||||
|
||||
class TestGetRedis:
    """Checks for ``get_redis`` when it is called directly with keyword arguments."""

    def test_default_instance(self):
        """With no arguments the result is an instance of ``defaults.REDIS_CLS``."""
        assert isinstance(get_redis(), defaults.REDIS_CLS)

    def test_custom_class(self):
        """``redis_cls`` is called with the remaining keyword arguments."""
        fake_cls = mock.Mock()

        client = get_redis(param="foo", redis_cls=fake_cls)

        assert client is fake_cls.return_value
        fake_cls.assert_called_with(param="foo")

    def test_from_url(self):
        """A ``url`` argument routes construction through ``redis_cls.from_url``."""
        fake_cls = mock.Mock()
        redis_url = "redis://localhost"

        client = get_redis(redis_cls=fake_cls, url=redis_url, param="foo")

        assert client is fake_cls.from_url.return_value
        fake_cls.from_url.assert_called_with(redis_url, param="foo")
|
||||
|
||||
|
||||
class TestFromSettings:
    """Checks for ``from_settings`` driven by a ``REDIS_PARAMS`` setting."""

    def setup_method(self):
        """Build a mock client class and settings that point ``REDIS_PARAMS`` at it.

        This is pytest's xunit-style ``setup_method`` hook. The nose-style
        ``setup`` name is deprecated since pytest 7.2 and is no longer called
        by pytest 8, which would leave the attributes below unset.
        """
        self.redis_cls = mock.Mock()
        self.expected_params = {
            "timeout": 0,
            "flag": False,
        }
        self.settings = Settings(
            {
                "REDIS_PARAMS": dict(self.expected_params, redis_cls=self.redis_cls),
            }
        )

    def test_redis_cls_default(self):
        """Empty settings produce an instance of ``defaults.REDIS_CLS``."""
        server = from_settings(Settings())
        assert isinstance(server, defaults.REDIS_CLS)

    def test_redis_cls_custom_path(self):
        """``redis_cls`` may be given as a dotted import path string."""
        self.settings["REDIS_PARAMS"]["redis_cls"] = "unittest.mock.Mock"
        server = from_settings(self.settings)
        assert isinstance(server, mock.Mock)

    def test_default_params(self):
        """The client class receives the defaults merged with the configured params."""
        server = from_settings(self.settings)
        assert server is self.redis_cls.return_value
        self.redis_cls.assert_called_with(
            **dict(defaults.REDIS_PARAMS, **self.expected_params)
        )

    def test_override_default_params(self):
        """Every key of ``defaults.REDIS_PARAMS`` can be overridden via settings."""
        # One fresh sentinel per key, stored both in the expectation and in
        # the settings, so the final call check proves each override arrived.
        for key in defaults.REDIS_PARAMS:
            self.expected_params[key] = self.settings["REDIS_PARAMS"][key] = object()

        server = from_settings(self.settings)
        assert server is self.redis_cls.return_value
        self.redis_cls.assert_called_with(**self.expected_params)
|
||||
|
||||
|
||||
def test_get_server_from_settings_alias():
    """``get_redis_from_settings`` and ``from_settings`` are the same object."""
    assert get_redis_from_settings is from_settings
|
||||
@@ -0,0 +1,108 @@
|
||||
from unittest import mock
|
||||
|
||||
from scrapy.http import Request
|
||||
from scrapy.settings import Settings
|
||||
|
||||
from scrapy_redis.dupefilter import RFPDupeFilter
|
||||
|
||||
|
||||
def get_redis_mock():
    """Return a ``mock.Mock`` server whose ``sadd`` keeps sets in memory.

    Each call creates an independent store, so fingerprints recorded through
    one mock server are never visible through another. All attributes other
    than ``sadd`` remain ordinary auto-created mock attributes.
    """
    server = mock.Mock()
    # key -> set of members; lives in this closure instead of a mutable
    # default argument, so the state is explicit and not part of the signature.
    store = {}

    def sadd(key, fp, added=0):
        """Add ``fp`` to the set at ``key``; return ``added`` plus one if it was new."""
        members = store.setdefault(key, set())
        if fp in members:
            return added
        members.add(fp)
        return added + 1

    server.sadd = sadd

    return server
|
||||
|
||||
|
||||
class TestRFPDupeFilter:
    """Checks for ``RFPDupeFilter`` running against the in-memory mock server."""

    def setup_method(self):
        """Create a fresh mock server and a dupefilter bound to a fixed key.

        This is pytest's xunit-style ``setup_method`` hook. The nose-style
        ``setup`` name is deprecated since pytest 7.2 and is no longer called
        by pytest 8, which would leave the attributes below unset.
        """
        self.server = get_redis_mock()
        self.key = "dupefilter:1"
        self.df = RFPDupeFilter(self.server, self.key)

    def test_request_seen(self):
        """A request is new once; a different method or URL is a different request."""
        req = Request("http://example.com")

        # Same request: unseen the first time, seen from then on.
        assert not self.df.request_seen(req)
        assert self.df.request_seen(req)

        # Same URL with a different HTTP method.
        post_req = Request("http://example.com", method="POST")
        assert self.df.request_seen(req)
        assert not self.df.request_seen(post_req)

        # Different URL.
        other_url_req = Request("http://example2.com")
        assert self.df.request_seen(req)
        assert not self.df.request_seen(other_url_req)

    def test_overridable_request_fingerprinter(self):
        """``request_seen`` goes through the instance's ``request_fingerprint``."""
        req = Request("http://example.com")
        self.df.request_fingerprint = mock.Mock(wraps=self.df.request_fingerprint)
        assert not self.df.request_seen(req)
        self.df.request_fingerprint.assert_called_with(req)

    def test_clear_deletes(self):
        """``clear`` calls ``delete`` on the server with the dupefilter key."""
        self.df.clear()
        self.server.delete.assert_called_with(self.key)

    def test_close_calls_clear(self):
        """``close`` calls ``clear`` both with and without a ``reason``."""
        self.df.clear = mock.Mock(wraps=self.df.clear)
        self.df.close()
        self.df.close(reason="foo")
        assert self.df.clear.call_count == 2
|
||||
|
||||
|
||||
def test_log_dupes():
    """Count ``logger.debug`` calls made by ``log`` with and without ``debug=True``."""
    server = get_redis_mock()

    # (extra constructor kwargs, number of log() calls, expected debug() calls)
    scenarios = (
        ({}, 5, 1),  # debug left at its default
        ({"debug": True}, 5, 5),
    )

    for extra_kwargs, dupes, logcount in scenarios:
        df = RFPDupeFilter(server, "foo", **extra_kwargs)
        # Wrap the logger method so calls are counted but still executed.
        df.logger.debug = mock.Mock(wraps=df.logger.debug)
        for _ in range(dupes):
            df.log(Request("http://example"), spider=mock.Mock())
        assert df.logger.debug.call_count == logcount
|
||||
|
||||
|
||||
@mock.patch("scrapy_redis.dupefilter.get_redis_from_settings")
class TestFromMethods:
    """Checks for the ``from_settings`` / ``from_crawler`` constructors.

    The class-level patch replaces ``get_redis_from_settings`` in the
    dupefilter module and passes the mock to every ``test*`` method.
    """

    def setup_method(self):
        """Prepare settings with dupefilter debugging switched on.

        This is pytest's xunit-style ``setup_method`` hook. The nose-style
        ``setup`` name is deprecated since pytest 7.2 and is no longer called
        by pytest 8, which would leave ``self.settings`` unset.
        """
        self.settings = Settings(
            {
                "DUPEFILTER_DEBUG": True,
            }
        )

    def test_from_settings(self, get_redis_from_settings):
        """``from_settings`` builds a dupefilter from a ``Settings`` object."""
        df = RFPDupeFilter.from_settings(self.settings)
        self.assert_dupefilter(df, get_redis_from_settings)

    def test_from_crawler(self, get_redis_from_settings):
        """``from_crawler`` builds a dupefilter from the crawler's settings."""
        crawler = mock.Mock(settings=self.settings)
        df = RFPDupeFilter.from_crawler(crawler)
        self.assert_dupefilter(df, get_redis_from_settings)

    def assert_dupefilter(self, df, get_redis_from_settings):
        """Check the server, key prefix and debug flag of a built dupefilter."""
        assert df.server is get_redis_from_settings.return_value
        assert df.key.startswith("dupefilter:")
        assert df.debug  # DUPEFILTER_DEBUG was set to True in the settings
|
||||
@@ -0,0 +1,7 @@
|
||||
import scrapy_redis
|
||||
|
||||
|
||||
def test_package_metadata():
    """The package exposes non-empty author, email and version attributes."""
    for attribute in ("__author__", "__email__", "__version__"):
        assert getattr(scrapy_redis, attribute)
|
||||
@@ -0,0 +1,18 @@
|
||||
from scrapy_redis import picklecompat
|
||||
|
||||
|
||||
def test_picklecompat():
    """A request-shaped dict survives a ``dumps`` / ``loads`` round trip unchanged."""
    payload = {
        "_encoding": "utf-8",
        "body": "",
        "callback": "_response_downloaded",
        "cookies": {},
        "dont_filter": False,
        "errback": None,
        "headers": {"Referer": ["http://www.dmoz.org/"]},
        # Includes a non-ASCII character in the link text.
        "meta": {"depth": 1, "link_text": "Fran\xe7ais", "rule": 0},
        "method": "GET",
        "priority": 0,
        "url": "http://www.dmoz.org/World/Fran%C3%A7ais/",
    }

    serialized = picklecompat.dumps(payload)
    restored = picklecompat.loads(serialized)

    assert payload == restored
|
||||
@@ -0,0 +1,38 @@
|
||||
from unittest import mock
|
||||
|
||||
from scrapy import Spider
|
||||
from scrapy.http import Request
|
||||
|
||||
from scrapy_redis.queue import Base
|
||||
|
||||
|
||||
class TestBaseQueue:
    """Checks for request encoding and decoding on the ``Base`` queue class."""

    # Queue class under test; every test in this class builds queues from it.
    queue_cls = Base

    def setup_method(self):
        """Create a mock server, a spider and a queue bound to them.

        This is pytest's xunit-style ``setup_method`` hook. The nose-style
        ``setup`` name is deprecated since pytest 7.2 and is no longer called
        by pytest 8, which would leave the attributes below unset.
        """
        self.server = mock.Mock()
        self.spider = Spider(name="foo")
        self.spider.parse_method = lambda x: x
        self.key = "key"
        self.q = self.queue_cls(self.server, self.spider, self.key)

    def test_encode_decode_requests(self, q=None):
        """A request keeps its URL, meta and callback after encode then decode.

        ``q`` defaults to the queue built in ``setup_method``; other tests
        pass their own queue to reuse these assertions.
        """
        if q is None:
            q = self.q
        req = Request(
            "http://example.com", callback=self.spider.parse, meta={"foo": "bar"}
        )
        out = q._decode_request(q._encode_request(req))
        assert req.url == out.url
        assert req.meta == out.meta
        assert req.callback == out.callback

    def test_custom_serializer(self):
        """A custom serializer is used exactly once in each direction."""
        serializer = mock.Mock()
        # Identity functions: the round trip still works, and calls are counted.
        serializer.dumps = mock.Mock(side_effect=lambda x: x)
        serializer.loads = mock.Mock(side_effect=lambda x: x)
        # Use queue_cls rather than a hard-coded Base, matching setup_method.
        q = self.queue_cls(self.server, self.spider, self.key, serializer=serializer)
        self.test_encode_decode_requests(q)
        assert serializer.dumps.call_count == 1
        assert serializer.loads.call_count == 1
|
||||
@@ -0,0 +1,296 @@
|
||||
import os
|
||||
from unittest import TestCase, mock
|
||||
|
||||
import redis
|
||||
from scrapy import Request, Spider
|
||||
from scrapy.settings import Settings
|
||||
from scrapy.utils.test import get_crawler
|
||||
|
||||
from scrapy_redis import connection
|
||||
from scrapy_redis.dupefilter import RFPDupeFilter
|
||||
from scrapy_redis.queue import FifoQueue, LifoQueue, PriorityQueue
|
||||
from scrapy_redis.scheduler import Scheduler
|
||||
|
||||
# Redis connection target for these tests. Each value can be overridden with
# an environment variable of the same name.
REDIS_HOST = os.getenv("REDIS_HOST", "localhost")
REDIS_PORT = int(os.getenv("REDIS_PORT", 6379))
|
||||
|
||||
|
||||
def get_spider(*args, **kwargs):
    """Create a spider through a crawler built by ``get_crawler``.

    ``spidercls`` and ``settings_dict`` are removed from ``kwargs`` (default
    ``None``) and passed to ``get_crawler``. All remaining positional and
    keyword arguments go to the crawler's ``_create_spider``.
    """
    spidercls = kwargs.pop("spidercls", None)
    settings_dict = kwargs.pop("settings_dict", None)
    crawler = get_crawler(spidercls=spidercls, settings_dict=settings_dict)
    return crawler._create_spider(*args, **kwargs)
|
||||
|
||||
|
||||
class RedisTestMixin:
    """Mixin that gives test cases a Redis client and a key clean-up helper."""

    @property
    def server(self):
        """Return the Redis client, creating and caching it on first access."""
        try:
            return self._redis
        except AttributeError:
            # First access on this instance: connect and remember the client.
            self._redis = redis.Redis(REDIS_HOST, REDIS_PORT)
            return self._redis

    def clear_keys(self, prefix):
        """Delete every key on the server whose name starts with ``prefix``."""
        matching = self.server.keys(prefix + "*")
        if not matching:
            # Nothing matched, so delete is not called at all.
            return
        self.server.delete(*matching)
|
||||
|
||||
|
||||
class DupeFilterTest(RedisTestMixin, TestCase):
    # Runs RFPDupeFilter against the Redis server provided by RedisTestMixin.

    def setUp(self):
        # Namespaced key so tearDown can remove everything this test wrote.
        self.key = "scrapy_redis:tests:dupefilter:"
        self.df = RFPDupeFilter(self.server, self.key)

    def tearDown(self):
        self.clear_keys(self.key)

    def test_dupe_filter(self):
        # The same request is unseen on the first check and seen on the second.
        request = Request("http://example.com")

        first_check = self.df.request_seen(request)
        self.assertFalse(first_check)

        second_check = self.df.request_seen(request)
        self.assertTrue(second_check)

        self.df.close("nothing")
|
||||
|
||||
|
||||
class QueueTestMixin(RedisTestMixin):
    # Shared fixture and tests for queue classes; subclasses set queue_cls.

    queue_cls = None

    def setUp(self):
        self.spider = get_spider(name="myspider")
        self.key = "scrapy_redis:tests:{}:queue".format(self.spider.name)
        self.q = self.queue_cls(self.server, Spider("myspider"), self.key)

    def tearDown(self):
        self.clear_keys(self.key)

    def test_clear(self):
        # The queue starts empty, holds ten pushed requests, and is empty
        # again after clear().
        self.assertEqual(len(self.q), 0)

        # XXX: every request needs its own URL. Per the original note,
        # SpiderPriorityQueue is backed by a Redis set, so identical
        # serialized requests collapse into a single entry and the length
        # check below would fail. In effect that queue acts as a duplicate
        # filter for identical serialized requests, which may be unwanted
        # for repeated requests to one page even with dont_filter=True.
        for page in range(10):
            self.q.push(Request(f"http://example.com/?page={page}"))
        self.assertEqual(len(self.q), 10)

        self.q.clear()
        self.assertEqual(len(self.q), 0)
|
||||
|
||||
|
||||
class FifoQueueTest(QueueTestMixin, TestCase):
    # FifoQueue must return requests in the order they were pushed.

    queue_cls = FifoQueue

    def test_queue(self):
        first = Request("http://example.com/page1")
        second = Request("http://example.com/page2")

        for request in (first, second):
            self.q.push(request)

        # Pop once without a timeout and once with one.
        popped_first = self.q.pop()
        popped_second = self.q.pop(timeout=1)

        self.assertEqual(popped_first.url, first.url)
        self.assertEqual(popped_second.url, second.url)
|
||||
|
||||
|
||||
class PriorityQueueTest(QueueTestMixin, TestCase):
    # PriorityQueue must return requests from highest to lowest priority.

    queue_cls = PriorityQueue

    def test_queue(self):
        # Each request has its own URL so that every URL assertion below
        # identifies exactly one request. With a shared URL the check could
        # not tell the highest-priority request from the lowest-priority one.
        req1 = Request("http://example.com/page1", priority=100)
        req2 = Request("http://example.com/page2", priority=50)
        req3 = Request("http://example.com/page3", priority=200)

        self.q.push(req1)
        self.q.push(req2)
        self.q.push(req3)

        # Pop without a timeout, with a zero timeout and with a one-second one.
        out1 = self.q.pop()
        out2 = self.q.pop(timeout=0)
        out3 = self.q.pop(timeout=1)

        # Expected order by priority: 200, then 100, then 50.
        self.assertEqual(out1.url, req3.url)
        self.assertEqual(out2.url, req1.url)
        self.assertEqual(out3.url, req2.url)
|
||||
|
||||
|
||||
class LifoQueueTest(QueueTestMixin, TestCase):
    # LifoQueue must return the most recently pushed request first.

    queue_cls = LifoQueue

    def test_queue(self):
        older = Request("http://example.com/page1")
        newer = Request("http://example.com/page2")

        for request in (older, newer):
            self.q.push(request)

        # Pop once without a timeout and once with one.
        popped_first = self.q.pop()
        popped_second = self.q.pop(timeout=1)

        self.assertEqual(popped_first.url, newer.url)
        self.assertEqual(popped_second.url, older.url)
|
||||
|
||||
|
||||
class SchedulerTest(RedisTestMixin, TestCase):
    # Runs the Scheduler against the Redis server provided by RedisTestMixin.

    def setUp(self):
        # All keys share one prefix so tearDown can remove them in one sweep.
        self.key_prefix = "scrapy_redis:tests:"
        self.queue_key = f"{self.key_prefix}%(spider)s:requests"
        self.dupefilter_key = f"{self.key_prefix}%(spider)s:dupefilter"

        scheduler_settings = {
            "REDIS_HOST": REDIS_HOST,
            "REDIS_PORT": REDIS_PORT,
            "SCHEDULER_QUEUE_KEY": self.queue_key,
            "SCHEDULER_DUPEFILTER_KEY": self.dupefilter_key,
            "SCHEDULER_FLUSH_ON_START": False,
            "SCHEDULER_PERSIST": False,
            "SCHEDULER_SERIALIZER": "pickle",
            "DUPEFILTER_CLASS": "scrapy_redis.dupefilter.RFPDupeFilter",
        }
        self.spider = get_spider(name="myspider", settings_dict=scheduler_settings)
        self.scheduler = Scheduler.from_crawler(self.spider.crawler)

    def tearDown(self):
        self.clear_keys(self.key_prefix)

    def test_scheduler(self):
        scheduler = self.scheduler

        # Persistence is off, as configured in setUp.
        self.assertFalse(scheduler.persist)

        scheduler.open(self.spider)
        self.assertEqual(len(scheduler), 0)

        request = Request("http://example.com")
        scheduler.enqueue_request(request)
        self.assertTrue(scheduler.has_pending_requests())
        self.assertEqual(len(scheduler), 1)

        # Enqueueing the same request again must not grow the queue.
        scheduler.enqueue_request(request)
        self.assertEqual(len(scheduler), 1)

        popped = scheduler.next_request()
        self.assertEqual(popped.url, request.url)

        self.assertFalse(scheduler.has_pending_requests())
        self.assertEqual(len(scheduler), 0)

        scheduler.close("finish")

    def test_scheduler_persistent(self):
        # TODO: Improve this test to avoid the need to check for log messages.
        scheduler = self.scheduler
        self.spider.log = mock.Mock(spec=self.spider.log)

        scheduler.persist = True
        scheduler.open(self.spider)

        # Nothing is logged on the first open.
        self.assertEqual(self.spider.log.call_count, 0)

        for url in ("http://example.com/page1", "http://example.com/page2"):
            scheduler.enqueue_request(Request(url))

        self.assertTrue(scheduler.has_pending_requests())
        scheduler.close("finish")

        # Reopening reports the two requests left over from the first run.
        scheduler.open(self.spider)
        expected_log_calls = [
            mock.call("Resuming crawl (2 requests scheduled)"),
        ]
        self.spider.log.assert_has_calls(expected_log_calls)
        self.assertEqual(len(scheduler), 2)

        # With persistence switched off, closing empties the scheduler.
        scheduler.persist = False
        scheduler.close("finish")

        self.assertEqual(len(scheduler), 0)
|
||||
|
||||
|
||||
class ConnectionTest(TestCase):
    # Checks which connection arguments connection.from_settings derives
    # from the REDIS_URL / REDIS_HOST / REDIS_PORT settings.

    def _assert_connect_args(self, settings, **expected):
        # Build the client and compare the named connection-pool arguments
        # with the expected values, in the order they were given.
        server = connection.from_settings(settings)
        connect_args = server.connection_pool.connection_kwargs
        for name, value in expected.items():
            self.assertEqual(connect_args[name], value)

    # We can get a connection from just REDIS_URL.
    def test_redis_url(self):
        settings = Settings(
            {
                "REDIS_URL": "redis://foo:bar@localhost:9001/42",
            }
        )
        self._assert_connect_args(
            settings, host="localhost", port=9001, password="bar", db=42
        )

    # We can get a connection from REDIS_HOST/REDIS_PORT.
    def test_redis_host_port(self):
        settings = Settings(
            {
                "REDIS_HOST": "localhost",
                "REDIS_PORT": 9001,
            }
        )
        self._assert_connect_args(settings, host="localhost", port=9001)

    # REDIS_URL takes precedence over REDIS_HOST/REDIS_PORT.
    def test_redis_url_precedence(self):
        settings = Settings(
            {
                "REDIS_HOST": "baz",
                "REDIS_PORT": 1337,
                "REDIS_URL": "redis://foo:bar@localhost:9001/42",
            }
        )
        self._assert_connect_args(
            settings, host="localhost", port=9001, password="bar", db=42
        )

    # We fall back to REDIS_HOST/REDIS_PORT if REDIS_URL is None.
    def test_redis_host_port_fallback(self):
        settings = Settings(
            {"REDIS_HOST": "baz", "REDIS_PORT": 1337, "REDIS_URL": None}
        )
        self._assert_connect_args(settings, host="baz", port=1337)

    # We use default values for REDIS_HOST/REDIS_PORT.
    def test_redis_default(self):
        self._assert_connect_args(Settings(), host="localhost", port=6379)
|
||||
@@ -0,0 +1,197 @@
|
||||
import contextlib
|
||||
import os
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
from scrapy import signals
|
||||
from scrapy.exceptions import DontCloseSpider
|
||||
from scrapy.settings import Settings
|
||||
|
||||
from scrapy_redis.spiders import RedisCrawlSpider, RedisSpider
|
||||
|
||||
# Redis connection target for these tests. Each value can be overridden with
# an environment variable of the same name.
REDIS_HOST = os.getenv("REDIS_HOST", "localhost")
REDIS_PORT = int(os.getenv("REDIS_PORT", 6379))
|
||||
|
||||
|
||||
@contextlib.contextmanager
def flushall(server):
    """Run the ``with`` body, then always call ``server.flushall()``.

    The flush happens whether the body finishes normally or raises; an
    exception from the body still propagates to the caller.
    """
    with contextlib.ExitStack() as cleanup:
        # Registered up front and run on exit; the lambda defers the
        # attribute lookup until the flush actually happens.
        cleanup.callback(lambda: server.flushall())
        yield
|
||||
|
||||
|
||||
class MySpider(RedisSpider):
    """Minimal ``RedisSpider`` subclass used as a fixture by these tests."""

    name = "myspider"
|
||||
|
||||
|
||||
class MyCrawlSpider(RedisCrawlSpider):
    """Minimal ``RedisCrawlSpider`` subclass used as a fixture by these tests."""

    name = "myspider"
|
||||
|
||||
|
||||
def get_crawler(**kwargs):
    """Return a ``mock.Mock`` crawler whose ``settings`` point at the test Redis.

    Extra keyword arguments are passed straight to ``mock.Mock``.
    """
    redis_settings = Settings(
        {
            "REDIS_HOST": REDIS_HOST,
            "REDIS_PORT": REDIS_PORT,
        }
    )
    return mock.Mock(settings=redis_settings, **kwargs)
|
||||
|
||||
|
||||
class TestRedisMixin_setup_redis:
    """Checks for the validation and wiring done by ``setup_redis``."""

    def setup_method(self):
        """Create a spider that has no crawler attached yet.

        This is pytest's xunit-style ``setup_method`` hook. The nose-style
        ``setup`` name is deprecated since pytest 7.2 and is no longer called
        by pytest 8, which would leave ``self.myspider`` unset.
        """
        self.myspider = MySpider()

    def test_crawler_required(self):
        """Without a crawler, ``setup_redis`` raises a ``ValueError`` naming it."""
        with pytest.raises(ValueError) as excinfo:
            self.myspider.setup_redis()
        assert "crawler" in str(excinfo.value)

    def test_requires_redis_key(self):
        """An empty ``redis_key`` is rejected with a ``ValueError``."""
        self.myspider.crawler = get_crawler()
        self.myspider.redis_key = ""
        with pytest.raises(ValueError) as excinfo:
            self.myspider.setup_redis()
        assert "redis_key" in str(excinfo.value)

    def test_invalid_batch_size(self):
        """A non-numeric ``redis_batch_size`` is rejected with a ``ValueError``."""
        self.myspider.redis_batch_size = "x"
        self.myspider.crawler = get_crawler()
        with pytest.raises(ValueError) as excinfo:
            self.myspider.setup_redis()
        assert "redis_batch_size" in str(excinfo.value)

    def test_invalid_idle_time(self):
        """A non-numeric ``max_idle_time`` is rejected with a ``ValueError``."""
        self.myspider.max_idle_time = "x"
        self.myspider.crawler = get_crawler()
        with pytest.raises(ValueError) as excinfo:
            self.myspider.setup_redis()
        assert "max_idle_time" in str(excinfo.value)

    @mock.patch("scrapy_redis.spiders.connection")
    def test_via_from_crawler(self, connection):
        """``from_crawler`` sets the server and connects the idle signal once."""
        server = connection.from_settings.return_value = mock.Mock()
        crawler = get_crawler()
        myspider = MySpider.from_crawler(crawler)
        assert myspider.server is server
        connection.from_settings.assert_called_with(crawler.settings)
        crawler.signals.connect.assert_called_with(
            myspider.spider_idle, signal=signals.spider_idle
        )
        # Second call does nothing: same server, no new signal connection.
        server = myspider.server
        crawler.signals.connect.reset_mock()
        myspider.setup_redis()
        assert myspider.server is server
        assert crawler.signals.connect.call_count == 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "spider_cls",
    [
        MySpider,
        MyCrawlSpider,
    ],
)
def test_from_crawler_with_spider_arguments(spider_cls):
    """Spider arguments given to ``from_crawler`` end up on the spider.

    The key template is filled with the spider name, and the numeric options
    passed as strings come back as integers.
    """
    spider = spider_cls.from_crawler(
        get_crawler(),
        "foo",
        redis_key="key:%(name)s",
        redis_batch_size="2000",
        max_idle_time="100",
    )

    expected_attributes = {
        "name": "foo",
        "redis_key": "key:foo",
        "redis_batch_size": 2000,
        "max_idle_time": 100,
    }
    for attribute, value in expected_attributes.items():
        assert getattr(spider, attribute) == value
|
||||
|
||||
|
||||
class MockRequest(mock.Mock):
    """Mock stand-in for a request that compares and hashes by its URL."""

    def __init__(self, url, **kwargs):
        """Store ``url``; any other keyword arguments are accepted and ignored."""
        super().__init__()
        self.url = url

    def __eq__(self, other):
        """Compare by URL; defer to ``other`` when it has no ``url`` attribute."""
        try:
            return self.url == other.url
        except AttributeError:
            # Not request-like: let Python try the reflected comparison
            # instead of raising from inside ``==``.
            return NotImplemented

    def __hash__(self):
        """Hash by URL, consistent with ``__eq__``."""
        return hash(self.url)

    def __repr__(self):
        """Show the class name and URL, e.g. in assertion failure output."""
        return f"<{self.__class__.__name__}({self.url})>"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "spider_cls",
    [
        MySpider,
        MyCrawlSpider,
    ],
)
@pytest.mark.parametrize("start_urls_as_zset", [False, True])
@pytest.mark.parametrize("start_urls_as_set", [False, True])
@mock.patch("scrapy.spiders.Request", MockRequest)
def test_consume_urls_from_redis(start_urls_as_zset, start_urls_as_set, spider_cls):
    """Start URLs stored in Redis are consumed in two batches.

    Twice ``batch_size`` URLs are stored as a list, set or sorted set. The
    first batch must come from ``start_requests`` and the second must reach
    ``crawler.engine.crawl`` through ``spider_idle``.
    """
    batch_size = 5
    redis_key = "start:urls"
    # Sets and sorted sets give no insertion-order guarantee, so those
    # variants are checked by membership instead of by position.
    unordered = start_urls_as_zset or start_urls_as_set

    crawler = get_crawler()
    crawler.settings.setdict(
        {
            "REDIS_HOST": REDIS_HOST,
            "REDIS_PORT": REDIS_PORT,
            "REDIS_START_URLS_KEY": redis_key,
            "REDIS_START_URLS_AS_ZSET": start_urls_as_zset,
            "REDIS_START_URLS_AS_SET": start_urls_as_set,
            "CONCURRENT_REQUESTS": batch_size,
        }
    )
    spider = spider_cls.from_crawler(crawler)

    with flushall(spider.server):
        # Choose how URLs are written; the set flag wins when both are on.
        if start_urls_as_set:
            server_put = spider.server.sadd
        elif start_urls_as_zset:

            def server_put(key, value):
                spider.server.zadd(key, {value: 0})

        else:
            server_put = spider.server.rpush

        reqs = []
        for index in range(batch_size * 2):
            url = f"http://example.com/{index}"
            server_put(redis_key, url)
            reqs.append(MockRequest(url))

        # First call is to start requests.
        start_requests = list(spider.start_requests())
        if unordered:
            assert len(start_requests) == batch_size
            known_urls = {request.url for request in reqs}
            assert {request.url for request in start_requests} <= known_urls
        else:
            assert start_requests == reqs[:batch_size]

        # Second call is to spider idle method.
        with pytest.raises(DontCloseSpider):
            spider.spider_idle()
        # Process remaining requests in the queue.
        with pytest.raises(DontCloseSpider):
            spider.spider_idle()

        # Last batch was passed to crawl.
        assert crawler.engine.crawl.call_count == batch_size

        if unordered:
            remaining = [request for request in reqs if request not in start_requests]
            crawler.engine.crawl.assert_has_calls(
                [mock.call(request) for request in remaining],
                any_order=True,
            )
        else:
            crawler.engine.crawl.assert_has_calls(
                [mock.call(request) for request in reqs[batch_size:]]
            )
|
||||
@@ -0,0 +1,7 @@
|
||||
from scrapy_redis.utils import bytes_to_str
|
||||
|
||||
|
||||
def test_bytes_to_str():
    """``bytes_to_str`` decodes bytes, optionally with an explicit encoding."""
    cases = (
        ((b"foo",), "foo"),
        # This char is the same in bytes or latin1.
        ((b"\xc1", "latin1"), "\xc1"),
    )
    for call_args, expected in cases:
        assert bytes_to_str(*call_args) == expected
|
||||
Reference in New Issue
Block a user