Location via proxy:   [ UP ]  
[Report a bug]   [Manage cookies]                
Skip to content
This repository was archived by the owner on Feb 22, 2020. It is now read-only.

Commit cb4e46a

Browse files
authored
Merge pull request #302 from gnes-ai/rm-benchmark-client
refactor(client): remove benchmark client
2 parents a087626 + e588c94 commit cb4e46a

File tree

7 files changed

+18
-96
lines changed

7 files changed

+18
-96
lines changed

gnes/cli/api.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,6 @@ def client(args):
5858
return _client_http(args)
5959
elif args.client == 'cli':
6060
return _client_cli(args)
61-
elif args.client == 'benchmark':
62-
return _client_bm(args)
6361
else:
6462
raise ValueError('gnes client must follow with a client type from {http, cli, benchmark...}\n'
6563
'see "gnes client --help" for details')
@@ -94,11 +92,6 @@ def _client_cli(args):
9492
CLIClient(args)
9593

9694

97-
def _client_bm(args):
98-
from ..client.benchmark import BenchmarkClient
99-
BenchmarkClient(args)
100-
101-
10295
def compose(args):
10396
from ..composer.base import YamlComposer
10497
from ..composer.flask import YamlComposerFlask

gnes/cli/parser.py

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -365,21 +365,6 @@ def set_client_cli_parser(parser=None):
365365
return parser
366366

367367

368-
def set_client_benchmark_parser(parser=None):
369-
if not parser:
370-
parser = set_base_parser()
371-
_set_grpc_parser(parser)
372-
parser.add_argument('--batch_size', type=int, default=64,
373-
help='the size of the request to split')
374-
parser.add_argument('--request_length', type=int,
375-
default=1024,
376-
help='binary string length of each request')
377-
parser.add_argument('--num_requests', type=int,
378-
default=128,
379-
help='number of total requests')
380-
return parser
381-
382-
383368
def set_client_http_parser(parser=None):
384369
if not parser:
385370
parser = set_base_parser()
@@ -422,8 +407,6 @@ def get_main_parser():
422407
set_client_http_parser(
423408
spp.add_parser('http', help='start a client that allows HTTP requests as input', formatter_class=adf))
424409
set_client_cli_parser(spp.add_parser('cli', help='start a client that allows stdin as input', formatter_class=adf))
425-
set_client_benchmark_parser(
426-
spp.add_parser('benchmark', help='start a client for benchmark and unittest', formatter_class=adf))
427410

428411
# others
429412
set_composer_flask_parser(

gnes/client/benchmark.py

Lines changed: 0 additions & 52 deletions
This file was deleted.

gnes/encoder/base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def encode(self, text: List[str], *args, **kwargs) -> Union[Tuple, np.ndarray]:
4949

5050

5151
class BaseNumericEncoder(BaseEncoder):
52+
"""Note that all NumericEncoder can not be used as the first encoder of the pipeline"""
5253

5354
def encode(self, data: np.ndarray, *args, **kwargs) -> np.ndarray:
5455
pass

gnes/encoder/text/flair.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# limitations under the License.
1515

1616

17-
from typing import List
17+
from typing import List, Tuple
1818

1919
import numpy as np
2020

@@ -25,16 +25,22 @@
2525
class FlairEncoder(BaseTextEncoder):
2626
is_trained = True
2727

28-
def __init__(self, pooling_strategy: str = 'mean', *args, **kwargs):
28+
def __init__(self,
29+
word_embedding: str = 'glove',
30+
flair_embeddings: Tuple[str] = ('news-forward', 'news-backward'),
31+
pooling_strategy: str = 'mean', *args, **kwargs):
2932
super().__init__(*args, **kwargs)
33+
34+
self.word_embedding = word_embedding
35+
self.flair_embeddings = flair_embeddings
3036
self.pooling_strategy = pooling_strategy
3137

3238
def post_init(self):
3339
from flair.embeddings import DocumentPoolEmbeddings, WordEmbeddings, FlairEmbeddings
3440
self._flair = DocumentPoolEmbeddings(
35-
[WordEmbeddings('glove'),
36-
FlairEmbeddings('news-forward'),
37-
FlairEmbeddings('news-backward')],
41+
[WordEmbeddings(self.word_embedding),
42+
FlairEmbeddings(self.flair_embeddings[0]),
43+
FlairEmbeddings(self.flair_embeddings[1])],
3844
pooling=self.pooling_strategy)
3945

4046
@batching

tests/test_flair_encoder.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,14 @@ def setUp(self):
1717
if line:
1818
self.test_str.append(line)
1919

20-
self.flair_encoder = FlairEncoder(
21-
model_name=os.environ.get('FLAIR_CI_MODEL'),
22-
pooling_strategy="REDUCE_MEAN")
20+
self.flair_encoder = FlairEncoder(model_name=os.environ.get('FLAIR_CI_MODEL'))
2321

2422
@unittest.SkipTest
2523
def test_encoding(self):
26-
vec = self.flair_encoder.encode(self.test_str)
27-
self.assertEqual(vec.shape[0], len(self.test_str))
28-
self.assertEqual(vec.shape[1], 512)
24+
vec = self.flair_encoder.encode(self.test_str[:2])
25+
print(vec.shape)
26+
self.assertEqual(vec.shape[0], 2)
27+
self.assertEqual(vec.shape[1], 4196)
2928

3029
@unittest.SkipTest
3130
def test_dump_load(self):

tests/test_stream_grpc.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@
44

55
import grpc
66

7-
from gnes.cli.parser import set_frontend_parser, set_router_parser, set_client_benchmark_parser
8-
from gnes.client.benchmark import BenchmarkClient
7+
from gnes.cli.parser import set_frontend_parser, set_router_parser
98
from gnes.helper import TimeContext
109
from gnes.proto import RequestGenerator, gnes_pb2_grpc
1110
from gnes.service.base import SocketType, MessageHandler, BaseService as BS
@@ -55,13 +54,6 @@ def test_bm_frontend(self):
5554
'--yaml_path', 'BaseRouter'
5655
])
5756

58-
b_args = set_client_benchmark_parser().parse_args([
59-
'--num_requests', '10',
60-
'--request_length', '65536'
61-
])
62-
with RouterService(p_args), FrontendService(args):
63-
BenchmarkClient(b_args)
64-
6557
def test_grpc_frontend(self):
6658
args = set_frontend_parser().parse_args([
6759
'--grpc_host', '127.0.0.1',

0 commit comments

Comments
 (0)